Compare commits

..

2 Commits

Author        SHA1        Message           Date
pablodanswer  25b38212e9  nit               2025-01-19 09:50:35 -08:00
pablodanswer  3096b0b2a7  add linear check  2025-01-19 09:49:26 -08:00
271 changed files with 6451 additions and 9578 deletions

View File

@@ -11,4 +11,5 @@
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

View File

@@ -67,7 +67,6 @@ jobs:
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -60,8 +60,6 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -9,9 +9,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
PR_BODY="${{ github.event.pull_request.body }}"
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."

View File

@@ -39,12 +39,6 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -119,7 +119,7 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:

View File

@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"

View File

@@ -1,37 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -1,33 +0,0 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -32,7 +32,6 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -1,72 +1,30 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [
ee_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
},
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
return base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"

View File

@@ -98,9 +98,10 @@ def get_page_of_chat_sessions(
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id)
select(ChatSession.id, ChatSession.time_created)
.filter(*conditions)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.order_by(ChatSession.id, desc(ChatSession.time_created))
.distinct(ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
@@ -117,11 +118,7 @@ def get_page_of_chat_sessions(
ChatMessage.chat_message_feedbacks
),
)
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
)
return db_session.scalars(stmt).unique().all()
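
The ordering change above is tied to a PostgreSQL rule: `SELECT DISTINCT ON (col)` requires the `ORDER BY` clause to begin with that same column, so the id-paging subquery must order by `ChatSession.id` first before any secondary ordering by `time_created`. A minimal, hedged SQLAlchemy sketch of that subquery shape (the model here is a stand-in, not the application's real one):

```python
# Hedged sketch: why the subquery orders by ChatSession.id first when using DISTINCT ON.
from sqlalchemy import Column, DateTime, Integer, desc, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class ChatSession(Base):  # hypothetical stand-in for the real model
    __tablename__ = "chat_session"
    id = Column(Integer, primary_key=True)
    time_created = Column(DateTime)


page_size, page_num = 50, 0
subquery = (
    select(ChatSession.id, ChatSession.time_created)
    # DISTINCT ON (id) requires ORDER BY to start with id; time_created is secondary.
    .order_by(ChatSession.id, desc(ChatSession.time_created))
    .distinct(ChatSession.id)  # renders as DISTINCT ON (chat_session.id) on PostgreSQL
    .limit(page_size)
    .offset(page_num * page_size)
    .subquery()
)
```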

View File

@@ -42,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -67,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
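
The lines added here are a cache-first lookup keyed by permission id, so repeated calls skip the Drive API once every requested permission has been seen. A generic, hedged sketch of that pattern (the map name and fetch callable are placeholders, not the connector's real code):

```python
# Generic cache-first sketch of the pattern the added lines implement.
from typing import Callable, Iterable

_PERMISSION_CACHE: dict[str, dict] = {}


def get_permissions(permission_ids: list[str], fetch_all: Callable[[], Iterable[dict]]) -> list[dict]:
    cached = [_PERMISSION_CACHE[pid] for pid in permission_ids if pid in _PERMISSION_CACHE]
    if len(cached) == len(permission_ids):
        return cached  # every requested id was already cached; skip the API call
    fetched = list(fetch_all())  # otherwise refresh the cache from the API
    for permission in fetched:
        _PERMISSION_CACHE[permission["id"]] = permission
    return fetched
```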

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,

View File

@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -23,6 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
print("preferences_data", preferences_data)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(

View File

@@ -42,17 +42,8 @@ class UserCreate(schemas.BaseUserCreate):
tenant_id: str | None = None
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons
Role changes should be handled through a separate, admin-only process
"""
class AuthBackend(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"

View File

@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
@@ -215,7 +209,7 @@ def verify_email_domain(email: str) -> None:
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
@@ -245,8 +239,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
@@ -265,16 +261,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdateWithRole(
user_update = UserUpdate(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
else:
@@ -282,6 +278,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
@@ -583,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
@@ -599,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
def __init__(
self,
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
@@ -648,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
await redis.delete(f"{self.key_prefix}{token}")
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
elif AUTH_BACKEND == AuthBackend.POSTGRES:
auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
else:
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):

View File

@@ -23,8 +23,8 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -280,6 +280,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
@@ -317,8 +362,6 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
HttpxPool.close_all()
if not celery_is_worker_primary(sender):
return
@@ -467,13 +510,3 @@ def reset_tenant_id(
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
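
Both variants in this hunk implement the same poll-until-ready loop; one inlines it in `wait_for_vespa`, the other wraps a `wait_for_vespa_with_timeout` helper and raises `WorkerShutdown` when it fails. A hedged, generic sketch of that loop against an HTTP health endpoint (the interval, limit, and endpoint shape are illustrative assumptions):

```python
# Illustrative readiness-probe helper; not the repo's wait_for_vespa_with_timeout.
import time

import httpx


def wait_for_http_health(url: str, wait_interval: float = 5, wait_limit: float = 60) -> bool:
    """Poll a health endpoint until it reports 'up' or the time limit is reached."""
    start = time.monotonic()
    while True:
        try:
            response = httpx.get(url, timeout=wait_interval)
            response.raise_for_status()
            if response.json().get("status", {}).get("code") == "up":
                return True
        except Exception:
            pass  # not ready yet; keep polling until the limit
        if time.monotonic() - start > wait_limit:
            return False
        time.sleep(wait_interval)
```

A caller would then treat a False return as fatal, for example by raising WorkerShutdown as wait_for_vespa_or_shutdown does above.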

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -7,6 +8,7 @@ from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
@@ -79,7 +81,7 @@ class DynamicTenantScheduler(PersistentScheduler):
cloud_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": task.get("kwargs", {}),
"kwargs": {},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
@@ -130,25 +132,21 @@ class DynamicTenantScheduler(PersistentScheduler):
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
current_tenants = set()
for task_name, _ in current_schedule:
task_name = cast(str, task_name)
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
continue
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
if "_" in task_name:
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# -> "12345678-abcd-efgh-ijkl-12345678"
current_tenants.add(task_name.split("_")[-1])
logger.info(f"Found {len(current_tenants)} existing items in schedule")
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
for tenant_id in tenant_ids:
if tenant_id not in current_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)

View File

@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -10,10 +10,6 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
@@ -58,27 +54,16 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received.")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")

View File

@@ -1,13 +1,10 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -20,7 +17,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
@@ -158,25 +154,3 @@ def celery_is_worker_primary(worker: Any) -> bool:
return True
return False
def httpx_init_vespa_pool(
max_keepalive_connections: int,
timeout: int = VESPA_REQUEST_TIMEOUT,
ssl_cert: str | None = None,
ssl_key: str | None = None,
) -> None:
httpx_cert = None
httpx_verify = False
if ssl_cert and ssl_key:
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
httpx_verify = True
HttpxPool.init_client(
name="vespa",
cert=httpx_cert,
verify=httpx_verify,
timeout=timeout,
http2=False,
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
)
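
The removed helper above configures a pooled httpx client with keepalive limits and an optional mTLS cert/key pair. For reference, a hedged sketch of the underlying httpx settings it passes through (paths and numbers are placeholders; the HttpxPool wrapper itself is not reproduced here):

```python
# Illustrative only: the httpx client settings the helper above forwards.
import httpx

limits = httpx.Limits(max_keepalive_connections=28)  # e.g. worker concurrency plus headroom
client = httpx.Client(
    cert=("/path/to/cert.pem", "/path/to/key.pem"),  # omit when no client cert is used
    verify=True,   # mirrors the helper: verification enabled when a cert/key pair is supplied
    timeout=15,    # seconds; stands in for VESPA_REQUEST_TIMEOUT
    http2=False,
    limits=limits,
)
```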

View File

@@ -16,241 +16,125 @@ from shared_configs.configs import MULTI_TENANT
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# the name attribute must start with ONYX_CELERY_CLOUD_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
# tasks that run in either self-hosted on cloud
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
if not MULTI_TENANT:
tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
), # Check every hour
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
},
}
)
# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
}
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule
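
Both versions of this module ultimately feed dicts of this shape into Celery beat; in the cloud variant, a single CLOUD_BEAT_TASK_GENERATOR entry fans out per-tenant tasks instead of listing them individually. A minimal, hedged sketch of how such entries map onto a beat schedule (the app, broker URL, and task name are assumptions, not the repo's configuration):

```python
# Minimal, illustrative Celery beat wiring for entries shaped like the dicts above.
from datetime import timedelta

from celery import Celery

app = Celery("example", broker="redis://localhost:6379/0")

tasks_to_schedule = [
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",            # placeholder task name
        "schedule": timedelta(seconds=15),
        "options": {"priority": 5, "expires": 15 * 60},
    },
]

# Each dict becomes one beat_schedule entry keyed by its name.
app.conf.beat_schedule = {
    entry["name"]: {
        "task": entry["task"],
        "schedule": entry["schedule"],
        "options": entry.get("options", {}),
        "kwargs": entry.get("kwargs", {}),
    }
    for entry in tasks_to_schedule
}
```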

View File

@@ -33,7 +33,6 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -140,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
submitted=datetime.now(timezone.utc),
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
redis_connector.delete.set_fence(fence_payload)
try:
@@ -178,13 +184,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
except TaskDependencyError:
redis_connector.delete.set_fence(None)
raise

View File

@@ -11,7 +11,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
@@ -32,17 +31,12 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import (
@@ -63,9 +57,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
"""Jobs / utils for kicking off doc permissions sync tasks."""
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
@@ -100,7 +91,6 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -183,19 +173,6 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
redis_connector.permissions.set_fence(payload)
result = app.send_task(
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
@@ -207,8 +184,11 @@ def try_creating_permissions_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# fill in the celery task id
payload.celery_task_id = result.id
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
@@ -418,53 +398,3 @@ def update_external_document_permissions_task(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
)
return False
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.permissions.reset()

View File

@@ -33,11 +33,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
@@ -95,7 +91,6 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -204,15 +199,6 @@ def try_creating_external_group_sync_task(
celery_task_id=result.id,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
@@ -302,26 +288,11 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.FAILED,
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e

View File

@@ -15,7 +15,7 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
@@ -23,32 +23,32 @@ from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_ta
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -68,12 +68,15 @@ logger = setup_logger()
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
debug_tenants = {
"tenant_i-043470d740845ec56",
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
}
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -120,18 +123,48 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# kick off index attempts
for cc_pair_id in cc_pair_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list = get_active_search_settings_list(db_session)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair search settings lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_connector_credential_pair_from_id: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -139,10 +172,28 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not cc_pair:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_last_attempt_for_cc_pair: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair should index: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
@@ -175,6 +226,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair.id, None, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
@@ -195,6 +255,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
tasks_created += 1
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# Fail any index attempts in the DB that don't have fences
@@ -204,7 +282,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
db_session, redis_client
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing after get unfenced lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
for attempt_id in unfenced_attempt_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced attempt id lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
@@ -222,6 +317,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
attempt.id, db_session, failure_reason=failure_reason
)
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing validate fences lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
@@ -230,7 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# or be currently executing
try:
validate_indexing_fences(
tenant_id, redis_client_replica, redis_client_celery, lock_beat
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
@@ -304,14 +408,6 @@ def connector_indexing_task(
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
@@ -578,9 +674,6 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
@@ -687,10 +780,67 @@ def connector_indexing_proxy_task(
)
continue
redis_connector_index.set_watchdog(False)
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
trail=False,
bind=True,
)
def cloud_check_for_indexing(self: Task) -> bool | None:
"""a lightweight task used to kick off individual check tasks for each tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
try:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
self.app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(
tenant_id=tenant_id,
),
priority=OnyxCeleryPriority.HIGH,
expires=BEAT_EXPIRES_DEFAULT,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud indexing check")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_for_indexing - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_for_indexing finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -291,20 +291,17 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str | None,
r_replica: Redis,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
for key_bytes in r_replica.scan_iter(
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()

View File

@@ -14,16 +14,8 @@ from onyx.db.models import LLMProvider
def _process_model_list_response(model_list_json: Any) -> list[str]:
# Handle case where response is wrapped in a "data" field
if isinstance(model_list_json, dict):
if "data" in model_list_json:
model_list_json = model_list_json["data"]
elif "models" in model_list_json:
model_list_json = model_list_json["models"]
else:
raise ValueError(
"Invalid response from API - expected dict with 'data' or "
f"'models' field, got {type(model_list_json)}"
)
if isinstance(model_list_json, dict) and "data" in model_list_json:
model_list_json = model_list_json["data"]
if not isinstance(model_list_json, list):
raise ValueError(
@@ -35,18 +27,11 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
for item in model_list_json:
if isinstance(item, str):
model_names.append(item)
elif isinstance(item, dict):
if "model_name" in item:
model_names.append(item["model_name"])
elif "id" in item:
model_names.append(item["id"])
else:
raise ValueError(
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
)
elif isinstance(item, dict) and "model_name" in item:
model_names.append(item["model_name"])
else:
raise ValueError(
f"Invalid item in model list - expected string or dict, got {type(item)}"
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
)
return model_names
@@ -54,7 +39,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
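
The two versions of `_process_model_list_response` differ in which payload shapes they accept (a bare list, a dict wrapped under "data" or "models", items keyed by "model_name" or "id"). A simplified, hedged sketch of the more permissive behavior, purely for illustration:

```python
# Simplified illustration, not the repo's parser.
from typing import Any


def extract_model_names(payload: Any) -> list[str]:
    """Accept a bare list or a dict wrapping the list under 'data'/'models'."""
    if isinstance(payload, dict):
        payload = payload.get("data") or payload.get("models") or []
    names: list[str] = []
    for item in payload:
        if isinstance(item, str):
            names.append(item)
        elif isinstance(item, dict) and ("model_name" in item or "id" in item):
            names.append(item.get("model_name") or item["id"])
    return names


assert extract_model_names({"data": [{"id": "gpt-4o"}, "text-embedding-3-small"]}) == [
    "gpt-4o",
    "text-embedding-3-small",
]
```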

View File

@@ -1,10 +1,7 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
from typing import Literal
from celery import shared_task
from celery import Task
@@ -13,34 +10,26 @@ from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings_list
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -52,17 +41,6 @@ _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)
_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_START_LATENCY_KEY_FMT = (
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)
_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
@@ -125,7 +103,6 @@ class Metric(BaseModel):
}.items()
if v is not None
}
task_logger.info(f"Emitting metric: {data}")
optional_telemetry(
record_type=RecordType.METRIC,
data=data,
@@ -200,111 +177,45 @@ def _build_connector_start_latency_metric(
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
task_logger.info(
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
)
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
return Metric(
key=metric_key,
name="connector_start_latency",
value=start_latency,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
tags={},
)
def _build_connector_final_metrics(
def _build_run_success_metrics(
cc_pair: ConnectorCredentialPair,
recent_attempts: list[IndexAttempt],
redis_std: Redis,
) -> list[Metric]:
"""
Final metrics for connector index attempts:
- Boolean success/fail metric
- If success, emit:
* duration (seconds)
* doc_count
"""
metrics = []
for attempt in recent_attempts:
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping final metrics for connector {cc_pair.connector.id} "
f"index attempt {attempt.id}, already emitted."
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {attempt.id} because it has already been "
"emitted"
)
continue
# We only emit final metrics if the attempt is in a terminal state
if attempt.status not in [
if attempt.status in [
IndexingStatus.SUCCESS,
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
# Not finished; skip
continue
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
success = attempt.status == IndexingStatus.SUCCESS
metrics.append(
Metric(
key=metric_key, # We'll mark the same key for any final metrics
name="connector_run_succeeded",
value=success,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
"status": attempt.status.value,
},
)
)
if success:
# Make sure we have valid time_started
if attempt.time_started and attempt.time_updated:
duration_seconds = (
attempt.time_updated - attempt.time_started
).total_seconds()
metrics.append(
Metric(
key=None, # No need for a new key, or you can reuse the same if you prefer
name="connector_index_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
else:
task_logger.error(
f"Index attempt {attempt.id} succeeded but has missing time "
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
)
# For doc counts, choose whichever field is more relevant
doc_count = attempt.total_docs_indexed or 0
metrics.append(
Metric(
key=None,
name="connector_index_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
key=metric_key,
name="connector_run_succeeded",
value=attempt.status == IndexingStatus.SUCCESS,
tags={"source": str(cc_pair.connector.source)},
)
)
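
Both variants reduce to the same guard: only attempts in a terminal state produce a success/failure metric, and the Redis key ensures each attempt is reported at most once. A condensed sketch of that core, leaving out the extra duration and doc-count metrics of the longer variant:

TERMINAL_STATUSES = {
    IndexingStatus.SUCCESS,
    IndexingStatus.FAILED,
    IndexingStatus.CANCELED,
}

def build_success_metric(cc_pair, attempt, redis_std) -> Metric | None:
    key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
        cc_pair_id=cc_pair.id, index_attempt_id=attempt.id
    )
    if attempt.status not in TERMINAL_STATUSES or _has_metric_been_emitted(redis_std, key):
        return None  # still running, or already reported
    return Metric(
        key=key,
        name="connector_run_succeeded",
        value=attempt.status == IndexingStatus.SUCCESS,
        tags={"source": str(cc_pair.connector.source)},
    )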
@@ -313,337 +224,178 @@ def _build_connector_final_metrics(
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""Collect metrics about connector runs from the past hour"""
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all connector credential pairs
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
# Might be more than one search setting, or just one
active_search_settings_list = get_active_search_settings_list(db_session)
metrics = []
# If you want to process each cc_pair against each search setting:
for cc_pair in cc_pairs:
for search_settings in active_search_settings_list:
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.search_settings_id == search_settings.id,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
# Get all attempts in the last hour
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.time_created >= one_hour_ago,
)
.order_by(IndexAttempt.time_created.desc())
.all()
)
most_recent_attempt = recent_attempts[0] if recent_attempts else None
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
if not recent_attempts:
continue
# if no metric to emit, skip
if most_recent_attempt is None:
continue
most_recent_attempt = recent_attempts[0]
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
if one_hour_ago > most_recent_attempt.time_created:
continue
# Build a job_id for correlation
job_id = build_job_id(
"connector", str(cc_pair.id), str(most_recent_attempt.id)
)
# Add raw start time metric if available
if most_recent_attempt.time_started:
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=start_time_key,
name="connector_start_time",
value=most_recent_attempt.time_started.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Add raw end time metric if available and in terminal state
if (
most_recent_attempt.status.is_terminal()
and most_recent_attempt.time_updated
):
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=end_time_key,
name="connector_end_time",
value=most_recent_attempt.time_updated.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Connector run success/failure
final_metrics = _build_connector_final_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(final_metrics)
# Connector run success/failure
run_success_metrics = _build_run_success_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(run_success_metrics)
return metrics
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""
Collect metrics for document set and group syncing:
- Success/failure status
- Start latency (for doc sets / user groups)
- Duration & doc count (only if success)
- Throughput (docs/min) (only if success)
- Raw start/end times for each sync
"""
"""Collect metrics about document set and group syncing speed"""
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records that ended in the last hour
# Get all sync records from the last hour
recent_sync_records = db_session.scalars(
select(SyncRecord)
.where(SyncRecord.sync_end_time.isnot(None))
.where(SyncRecord.sync_end_time >= one_hour_ago)
.order_by(SyncRecord.sync_end_time.desc())
.where(SyncRecord.sync_start_time >= one_hour_ago)
.order_by(SyncRecord.sync_start_time.desc())
).all()
task_logger.info(
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
)
metrics = []
for sync_record in recent_sync_records:
# Build a job_id for correlation
job_id = build_job_id("sync_record", str(sync_record.id))
# Skip if no end time (sync still in progress)
if not sync_record.sync_end_time:
continue
# Add raw start time metric
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
# Check if we already emitted a metric for this sync record
metric_key = (
f"sync_speed:{sync_record.sync_type}:"
f"{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.debug(
f"Skipping metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Calculate sync duration in minutes
sync_duration_mins = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds() / 60.0
# Calculate sync speed (docs/min) - avoid division by zero
sync_speed = (
sync_record.num_docs_synced / sync_duration_mins
if sync_duration_mins > 0
else None
)
if sync_speed is None:
task_logger.error(
"Something went wrong with sync speed calculation. "
f"Sync record: {sync_record.id}"
)
continue
metrics.append(
Metric(
key=start_time_key,
name="sync_start_time",
value=sync_record.sync_start_time.timestamp(),
key=metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# Add sync start latency metric
start_latency_key = (
f"sync_start_latency:{sync_record.sync_type}"
f":{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, start_latency_key):
task_logger.debug(
f"Skipping start latency metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
else:
# Skip other sync types
task_logger.debug(
f"Skipping sync record {sync_record.id} "
f"with type {sync_record.sync_type} "
f"and id {sync_record.entity_id} "
"because it is not a document set or user group"
)
continue
if entity is None:
task_logger.error(
f"Could not find entity for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
)
continue
# Calculate start latency in seconds
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
if start_latency < 0:
task_logger.error(
f"Start latency is negative for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
"This is likely because the entity was updated between the time the "
"time the sync finished and this job ran. Skipping."
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Add raw end time metric if available
if sync_record.sync_end_time:
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
metrics.append(
Metric(
key=end_time_key,
name="sync_end_time",
value=sync_record.sync_end_time.timestamp(),
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Emit a SUCCESS/FAIL boolean metric
# Use a single Redis key to avoid re-emitting final metrics
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, final_metric_key):
# Evaluate success
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
metrics.append(
Metric(
key=final_metric_key,
name="sync_run_succeeded",
value=sync_succeeded,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# If successful, emit additional metrics
if sync_succeeded:
if sync_record.sync_end_time and sync_record.sync_start_time:
duration_seconds = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds()
else:
task_logger.error(
f"Invalid times for sync record {sync_record.id}: "
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
)
duration_seconds = None
doc_count = sync_record.num_docs_synced or 0
sync_speed = None
if duration_seconds and duration_seconds > 0:
duration_mins = duration_seconds / 60.0
sync_speed = (
doc_count / duration_mins if duration_mins > 0 else None
)
# Emit duration, doc count, speed
if duration_seconds is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
metrics.append(
Metric(
key=final_metric_key,
name="sync_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
if sync_speed is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
# Emit start latency
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, start_latency_key):
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
if entity is None:
task_logger.error(
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
)
# Calculate start latency in seconds:
# (actual sync start) - (last modified time)
if (
entity is not None
and entity.time_last_modified_by_user
and sync_record.sync_start_time
):
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
if start_latency < 0:
task_logger.error(
f"Negative start latency for sync record {sync_record.id} "
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
return metrics
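
As a quick worked example of the two derived values tracked above: a document-set sync that starts 30 seconds after the set was last edited, runs for 2 minutes, and moves 500 documents emits a start latency of 30 seconds and a speed of 250 docs/min. A small sketch of just that arithmetic, detached from the Redis and ORM plumbing:

from datetime import datetime, timedelta

def sync_speed_docs_per_min(start: datetime, end: datetime, num_docs: int) -> float | None:
    duration_mins = (end - start).total_seconds() / 60.0
    # avoid division by zero for degenerate records
    return num_docs / duration_mins if duration_mins > 0 else None

def sync_start_latency_seconds(last_modified: datetime, sync_start: datetime) -> float | None:
    latency = (sync_start - last_modified).total_seconds()
    # negative values mean the entity changed again after the sync was scheduled
    return latency if latency >= 0 else None

start = datetime(2025, 1, 19, 10, 0, 30)
assert sync_speed_docs_per_min(start, start + timedelta(minutes=2), 500) == 250.0
assert sync_start_latency_seconds(start - timedelta(seconds=30), start) == 30.0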
def build_job_id(
job_type: Literal["connector", "sync_record"],
primary_id: str,
secondary_id: str | None = None,
) -> str:
if job_type == "connector":
if secondary_id is None:
raise ValueError(
"secondary_id (attempt_id) is required for connector job_type"
)
return f"connector:{primary_id}:attempt:{secondary_id}"
elif job_type == "sync_record":
return f"sync_record:{primary_id}"
@shared_task(
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
ignore_result=True,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
@@ -681,18 +433,14 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is not None and not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
metric.log()
metric.emit(tenant_id)
if metric.key:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
@@ -708,116 +456,3 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lock_monitoring.release()
task_logger.info("Background monitoring task finished")
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
This check is expected to fail if a cloud alembic migration is currently running
across all tenants.
TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_to_revision: dict[str, str | None] = {}
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
try:
# map each tenant_id to its revision
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
tenant_to_revision[tenant_id] = result_scalar
# get the total count of each revision
for k, v in tenant_to_revision.items():
if v is None:
continue
revision_counts[v] = revision_counts.get(v, 0) + 1
# get the revision with the most counts
sorted_revision_counts = sorted(
revision_counts.items(), key=lambda item: item[1], reverse=True
)
if len(sorted_revision_counts) == 0:
task_logger.error(
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
)
else:
top_revision, _ = sorted_revision_counts[0]
# build a list of out of date tenants
for k, v in tenant_to_revision.items():
if v == top_revision:
continue
out_of_date_tenants[k] = v
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud alembic check")
raise
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_alembic - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
if len(out_of_date_tenants) > 0:
task_logger.error(
f"Found out of date tenants: "
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
f"num_tenants={len(tenant_ids)} "
f"revision={top_revision}"
)
for k, v in islice(out_of_date_tenants.items(), 5):
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
else:
task_logger.info(
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True
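
The out-of-date detection above is essentially a majority vote: count how many tenants sit on each alembic revision, take the most common revision as the target, and flag everything else. A stripped-down sketch of that logic, assuming the per-tenant revisions have already been collected:

from collections import Counter

def find_out_of_date_tenants(
    tenant_to_revision: dict[str, str | None],
) -> tuple[str | None, dict[str, str | None]]:
    counts = Counter(rev for rev in tenant_to_revision.values() if rev is not None)
    if not counts:
        return None, {}
    top_revision, _ = counts.most_common(1)[0]
    out_of_date = {
        tenant: rev
        for tenant, rev in tenant_to_revision.items()
        if rev != top_revision
    }
    return top_revision, out_of_date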

View File

@@ -25,18 +25,13 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import pruning_ctx
@@ -45,9 +40,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
"""Jobs / utils for kicking off pruning tasks."""
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
@@ -86,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -212,14 +203,6 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
except Exception:
@@ -364,52 +347,3 @@ def connector_pruning_generator_task(
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
"""Monitoring pruning utils, called in monitor_vespa_sync"""
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.PRUNING,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
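
monitor_ccpair_pruning_taskset, monitor_ccpair_permissions_taskset, and monitor_ccpair_indexing_taskset all share the same fence/taskset shape. A generic sketch of that shape, with fence and finalize standing in for the per-task Redis wrapper and bookkeeping (DB update, optional sync record, fence reset):

def monitor_taskset(fence, finalize) -> None:
    if not fence.fenced:
        return  # nothing was started
    initial = fence.generator_complete
    if initial is None:
        return  # generator is still enumerating work
    if fence.get_remaining() > 0:
        return  # workers are still draining the taskset
    finalize(initial)  # all subtasks done; persist results and drop the fence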

View File

@@ -1,22 +1,15 @@
import time
from http import HTTPStatus
import httpx
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
@@ -25,16 +18,11 @@ from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@@ -80,11 +68,9 @@ def document_by_cc_pair_cleanup_task(
action = "skip"
chunks_affected = 0
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
active_search_settings.primary,
active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -213,78 +199,3 @@ def document_by_cc_pair_cleanup_task(
return False
return True
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
try:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
# needed in the cloud
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
self.app.send_task(
task_name,
kwargs=dict(
tenant_id=tenant_id,
),
queue=queue,
priority=priority,
expires=expires,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
finally:
if not lock_beat.owned():
task_logger.error(
"cloud_beat_task_generator - Lock not owned on completion"
)
redis_lock_dump(lock_beat, redis_client)
else:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -24,10 +24,6 @@ from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
monitor_ccpair_permissions_taskset,
)
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -38,6 +34,8 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
@@ -63,22 +61,23 @@ from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -98,7 +97,6 @@ logger = setup_logger()
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -652,6 +650,83 @@ def monitor_connector_deletion_taskset(
redis_connector.delete.reset()
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -660,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
@@ -710,7 +785,6 @@ def monitor_ccpair_indexing_taskset(
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
@@ -756,7 +830,7 @@ def monitor_ccpair_indexing_taskset(
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -766,20 +840,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
@@ -796,20 +856,11 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
ignore_result=True,
soft_time_limit=300,
bind=True,
)
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
@@ -825,17 +876,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# Replica usage notes
#
# False negatives are OK. (aka fail to see a key that exists on the master).
# We simply skip the monitoring work and it will be caught on the next pass.
#
# False positives are not OK, and are possible if we clear a fence on the master and
# then read from the replica. In this case, monitoring work could be done on a fence
# that no longer exists. To avoid this, we scan from the replica, but double check
# the result on the master.
r_replica = get_redis_replica_client(tenant_id=tenant_id)
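# A minimal sketch of the replica-then-master double check described above
# (handle_fence stands in for the per-fence monitoring call):
#
#     for key_bytes in r_replica.scan_iter(SOME_FENCE_PREFIX + "*"):
#         if not r.exists(key_bytes):
#             continue  # fence already cleared on the master; replica key is stale
#         handle_fence(key_bytes)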
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
@@ -895,19 +935,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
# scan and monitor activity to completion
phase_start = time.monotonic()
lock_beat.reacquire()
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
timings["connector"] = time.monotonic() - phase_start
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
timings["connector_deletion"] = time.monotonic() - phase_start
@@ -917,74 +955,66 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["documentset"] = time.monotonic() - phase_start
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["usergroup"] = time.monotonic() - phase_start
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["pruning"] = time.monotonic() - phase_start
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["indexing"] = time.monotonic() - phase_start
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["permissions"] = time.monotonic() - phase_start
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -1015,15 +1045,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -1069,13 +1095,7 @@ def vespa_metadata_sync_task(
# r = get_redis_client(tenant_id=tenant_id)
# r.delete(redis_syncing_key)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
except Exception as ex:

View File

@@ -35,7 +35,6 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
@@ -220,10 +219,9 @@ def _run_indexing(
callback=callback,
)
# Indexing is only done into one index at a time
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=ctx.index_name, secondary_index_name=None
)
indexing_pipeline = build_indexing_pipeline(

View File

@@ -254,7 +254,6 @@ def _get_force_search_settings(
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
new_msg_req.query_override is not None,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
@@ -426,7 +425,9 @@ def stream_chat_message_objects(
)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
@@ -498,6 +499,14 @@ def stream_chat_message_objects(
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session

View File

@@ -15,12 +15,11 @@ from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -32,16 +31,15 @@ def default_build_system_message(
prompt_config: PromptConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
if not tag_handled_prompt:
if not system_prompt:
return None
return SystemMessage(content=tag_handled_prompt)
system_msg = SystemMessage(content=system_prompt)
return system_msg
def default_build_user_message(
@@ -66,11 +64,8 @@ def default_build_user_message(
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg
@@ -91,7 +86,6 @@ class AnswerPromptBuilder:
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -100,21 +94,12 @@ class AnswerPromptBuilder:
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)
) = translate_history_to_basemessages(message_history)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

View File

@@ -21,9 +21,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
@@ -127,11 +127,10 @@ def build_citations_system_message(
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.include_citations:
system_prompt += REQUIRE_CITATION_STATEMENT
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt, prompt_config, add_additional_info_if_no_tag=True
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
return SystemMessage(content=tag_handled_prompt)
return SystemMessage(content=system_prompt)
def build_citations_user_message(

View File

@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
def _build_strong_llm_quotes_prompt(
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
).strip()
tag_handled_prompt = handle_onyx_date_awareness(
full_prompt, prompt, add_additional_info_if_no_tag=True
)
if prompt.datetime_aware:
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
return HumanMessage(content=tag_handled_prompt)
return HumanMessage(content=full_prompt)
def build_quotes_user_message(

View File

@@ -11,7 +11,6 @@ from onyx.llm.utils import build_content_with_imgs
def translate_onyx_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
exclude_images: bool = False,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
@@ -19,9 +18,7 @@ def translate_onyx_msg_to_langchain(
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -35,12 +32,9 @@ def translate_onyx_msg_to_langchain(
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
exclude_images: bool = False,
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_onyx_msg_to_langchain(msg, exclude_images)
for msg in history
if msg.token_count != 0
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts
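
One subtlety worth a small, purely illustrative example: because both comprehensions filter on token_count != 0, the returned message list and token-count list stay index-aligned even when placeholder messages are dropped (the tuples below are a stand-in for real ChatMessage objects):

# hypothetical history entries: (message text, token_count)
history = [("hi", 2), ("", 0), ("how are you?", 5)]

kept = [text for text, tokens in history if tokens != 0]
counts = [tokens for text, tokens in history if tokens != 0]

assert kept == ["hi", "how are you?"]
assert counts == [2, 5]  # counts[i] always corresponds to kept[i]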

View File

@@ -3,7 +3,6 @@ import os
import urllib.parse
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -56,12 +55,12 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
or 86400 * 7
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
@@ -93,12 +92,6 @@ OAUTH_CLIENT_SECRET = (
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
# By default, this is set to match the Redis expiry time for consistency.
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -200,8 +193,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"

View File

@@ -294,8 +294,7 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
class OnyxRedisSignals:
@@ -318,11 +317,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"
class OnyxCeleryTask:
DEFAULT = "celery"
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -330,10 +324,8 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
@@ -351,6 +343,8 @@ class OnyxCeleryTask:
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15

View File

@@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# These field types are considered metadata by default when
# treat_all_non_attachment_fields_as_metadata is False
DEFAULT_METADATA_FIELD_TYPES = {
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
@@ -60,35 +60,21 @@ class AirtableConnector(LoadConnector):
self,
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
@staticmethod
def _extract_field_values(
field_id: str,
field_info: Any,
field_type: str,
base_id: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> list[tuple[str, str]]:
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) + links from a field regardless of its type.
Attachments are represented as multiple sections, and therefore
returned as a list of tuples (value, link).
Extract value(s) from a field regardless of its type.
Returns either a single string or list of strings for attachments.
"""
if field_info is None:
return []
@@ -99,11 +85,8 @@ class AirtableConnector(LoadConnector):
if field_type == "multipleRecordLinks":
return []
# default link to use for non-attachment fields
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
if field_type == "multipleAttachments":
attachment_texts: list[tuple[str, str]] = []
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
@@ -126,7 +109,6 @@ class AirtableConnector(LoadConnector):
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_id = attachment["id"]
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
@@ -134,20 +116,7 @@ class AirtableConnector(LoadConnector):
extension=file_ext,
)
if attachment_text:
# slightly nicer loading experience if we can specify the view ID
if view_id:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
else:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
attachment_texts.append(
(f"{filename}:\n{attachment_text}", attachment_link)
)
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
@@ -162,31 +131,23 @@ class AirtableConnector(LoadConnector):
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [(" ".join(combined) if combined else str(field_info), default_link)]
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [str(item) for item in field_info]
return [(str(field_info), default_link)]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata.
When treat_all_non_attachment_fields_as_metadata is True, all fields except
attachments are treated as metadata. Otherwise, only fields with types listed
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
if self.treat_all_non_attachment_fields_as_metadata:
return field_type.lower() != "multipleattachments"
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
def _process_field(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
"""
@@ -204,21 +165,12 @@ class AirtableConnector(LoadConnector):
return [], {}
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)
if len(field_value_and_links) == 0:
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
field_values = [value for value, _ in field_value_and_links]
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
@@ -226,7 +178,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
Section(
link=link,
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
@@ -234,7 +186,7 @@ class AirtableConnector(LoadConnector):
"------------------------"
),
)
for text, link in field_value_and_links
for text in field_values
]
return sections, {}
@@ -243,7 +195,7 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document | None:
) -> Document:
"""Process a single Airtable record into a Document.
Args:
@@ -267,7 +219,6 @@ class AirtableConnector(LoadConnector):
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
view_id = table_schema.views[0].id if table_schema.views else None
for field_schema in table_schema.fields:
field_name = field_schema.name
@@ -275,22 +226,16 @@ class AirtableConnector(LoadConnector):
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
if not sections:
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
@@ -334,8 +279,7 @@ class AirtableConnector(LoadConnector):
table_schema=table_schema,
primary_field_name=primary_field_name,
)
if document:
record_documents.append(document)
record_documents.append(document)
if len(record_documents) >= self.batch_size:
yield record_documents
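
To make the metadata-versus-section routing above concrete, here is a standalone sketch of the decision rule (field names and the trimmed type set are illustrative, not the connector's full list):

# Hedged sketch of the field routing used by _should_be_metadata above.
DEFAULT_METADATA_FIELD_TYPES = {"createdby", "singlecollaborator", "collaborator"}

def should_be_metadata(field_type: str, treat_all_non_attachment_fields_as_metadata: bool) -> bool:
    # Attachments always become sections; everything else depends on the flag.
    if treat_all_non_attachment_fields_as_metadata:
        return field_type.lower() != "multipleattachments"
    return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

print(should_be_metadata("singleLineText", False))       # False -> indexed as a section
print(should_be_metadata("singleLineText", True))        # True  -> stored as metadata
print(should_be_metadata("multipleAttachments", True))   # False -> attachments stay sections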

View File

@@ -232,29 +232,20 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
}
# Get labels
label_dicts = (
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
)
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
label_dicts = confluence_object["metadata"]["labels"]["results"]
page_labels = [label["name"] for label in label_dicts]
if page_labels:
doc_metadata["labels"] = page_labels
# Get last modified and author email
version_dict = confluence_object.get("version", {})
last_modified = (
datetime_from_string(version_dict.get("when"))
if version_dict.get("when")
else None
)
author_email = version_dict.get("by", {}).get("email")
title = confluence_object.get("title", "Untitled Document")
last_modified = datetime_from_string(confluence_object["version"]["when"])
author_email = confluence_object["version"].get("by", {}).get("email")
return Document(
id=object_url,
sections=[Section(link=object_url, text=object_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=title,
semantic_identifier=confluence_object["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author_email)] if author_email else None

View File

@@ -50,9 +50,6 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
current_link = ""
current_text = ""
if transcript["sentences"] is None:
return None
for sentence in transcript["sentences"]:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:

View File

@@ -1,14 +1,16 @@
import io
import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import unquote
from typing import Optional
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from pydantic import BaseModel
from office365.onedrive.sites.site import Site # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -27,25 +29,16 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
Args:
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
If None, all folders will be accessed.
"""
url: str
drive_name: str | None
folder_path: str | None
@dataclass
class SiteData:
url: str | None
folder: Optional[str]
sites: list = field(default_factory=list)
driveitems: list = field(default_factory=list)
def _convert_driveitem_to_document(
driveitem: DriveItem,
drive_name: str,
) -> Document:
file_text = extract_file_text(
file=io.BytesIO(driveitem.get_content().execute_query().value),
@@ -65,7 +58,7 @@ def _convert_driveitem_to_document(
email=driveitem.last_modified_by.user.email,
)
],
metadata={"drive": drive_name},
metadata={},
)
return doc
@@ -77,172 +70,93 @@ class SharepointConnector(LoadConnector, PollConnector):
sites: list[str] = [],
) -> None:
self.batch_size = batch_size
self._graph_client: GraphClient | None = None
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
self.msal_app: msal.ConfidentialClientApplication | None = None
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
return self._graph_client
self.graph_client: GraphClient | None = None
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
@staticmethod
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
if "sites" in parts:
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
remaining_parts = parts[sites_index + 2 :]
# Extract drive name and folder path
if remaining_parts:
drive_name = unquote(remaining_parts[0])
folder_path = (
"/".join(unquote(part) for part in remaining_parts[1:])
if len(remaining_parts) > 1
else None
)
else:
drive_name = None
folder_path = None
folder = (
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
)
site_data_list.append(
SiteDescriptor(
url=site_url,
drive_name=drive_name,
folder_path=folder_path,
)
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
)
return site_data_list
def _fetch_driveitems(
def _populate_sitedata_driveitems(
self,
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> list[tuple[DriveItem, str]]:
) -> None:
filter_str = ""
if start is not None and end is not None:
filter_str = (
f"last_modified_datetime ge {start.isoformat()} and "
f"last_modified_datetime le {end.isoformat()}"
)
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
final_driveitems: list[tuple[DriveItem, str]] = []
try:
site = self.graph_client.sites.get_by_url(site_descriptor.url)
for element in self.site_data:
sites: list[Site] = []
for site in element.sites:
site_sublist = site.lists.get().execute_query()
sites.extend(site_sublist)
# Get all drives in the site
drives = site.drives.get().execute_query()
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
# Filter drives based on the requested drive name
if site_descriptor.drive_name:
drives = [
drive
for drive in drives
if drive.name == site_descriptor.drive_name
or (
drive.name == "Documents"
and site_descriptor.drive_name == "Shared Documents"
)
]
if not drives:
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
return []
# Process each matching drive
for drive in drives:
for site in sites:
try:
root_folder = drive.root
if site_descriptor.folder_path:
# If a specific folder is requested, navigate to it
for folder_part in site_descriptor.folder_path.split("/"):
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
query = root_folder.get_files(True, 1000)
query = site.drive.root.get_files(True, 1000)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
logger.debug(
f"Found {len(driveitems)} items in drive '{drive.name}'"
)
# Use "Shared Documents" as the library name for the default "Documents" drive
drive_name = (
"Shared Documents" if drive.name == "Documents" else drive.name
)
if site_descriptor.folder_path:
# Filter items to ensure they're in the specified folder or its subfolders
# The path will be in format: /drives/{drive_id}/root:/folder/path
if element.folder:
filtered_driveitems = [
(item, drive_name)
item
for item in driveitems
if any(
path_part == site_descriptor.folder_path
or path_part.startswith(
site_descriptor.folder_path + "/"
)
for path_part in item.parent_reference.path.split(
"root:/"
)[1].split("/")
)
if element.folder in item.parent_reference.path
]
if len(filtered_driveitems) == 0:
all_paths = [
item.parent_reference.path for item in driveitems
]
logger.warning(
f"Nothing found for folder '{site_descriptor.folder_path}' "
f"in; any of valid paths: {all_paths}"
)
final_driveitems.extend(filtered_driveitems)
element.driveitems.extend(filtered_driveitems)
else:
final_driveitems.extend(
[(item, drive_name) for item in driveitems]
)
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")
element.driveitems.extend(driveitems)
except Exception as e:
# Sites include things that do not contain drives so this fails
# but this is fine, as there are no actual documents in those
logger.warning(f"Failed to process site: {str(e)}")
except Exception:
# Sites include things that do not contain .drive.root so this fails
# but this is fine, as there are no actual documents in those
pass
return final_driveitems
def _populate_sitedata_sites(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
def _fetch_sites(self) -> list[SiteDescriptor]:
sites = self.graph_client.sites.get_all().execute_query()
site_descriptors = [
SiteDescriptor(
url=sites.resource_url,
drive_name=None,
folder_path=None,
)
]
return site_descriptors
if self.site_data:
for element in self.site_data:
element.sites = [
self.graph_client.sites.get_by_url(element.url)
.get()
.execute_query()
]
else:
sites = self.graph_client.sites.get_all().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
site_descriptors = self.site_descriptors or self._fetch_sites()
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
self._populate_sitedata_sites()
self._populate_sitedata_driveitems(start=start, end=end)
# goes over all urls, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for site_descriptor in site_descriptors:
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
for element in self.site_data:
for driveitem in element.driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
doc_batch.append(_convert_driveitem_to_document(driveitem))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -254,26 +168,22 @@ class SharepointConnector(LoadConnector, PollConnector):
sp_client_secret = credentials["sp_client_secret"]
sp_directory_id = credentials["sp_directory_id"]
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
client_credential=sp_client_secret,
)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
client_credential=sp_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
self._graph_client = GraphClient(_acquire_token_func)
self.graph_client = GraphClient(_acquire_token_func)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -282,19 +192,19 @@ class SharepointConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
if __name__ == "__main__":
connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
connector.load_credentials(
{
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
"sp_client_id": os.environ["SP_CLIENT_ID"],
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()
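
A compact sketch of the URL splitting documented on SiteDescriptor above (assumes the URL contains a /sites/ segment; host and paths are illustrative):

from urllib.parse import unquote

def split_site_url(url: str) -> tuple[str, str | None, str | None]:
    # Returns (site_url, drive_name, folder_path), mirroring the parsing described above.
    parts = url.strip().split("/")
    sites_index = parts.index("sites")
    site_url = "/".join(parts[: sites_index + 2])
    remaining = parts[sites_index + 2 :]
    drive_name = unquote(remaining[0]) if remaining else None
    folder_path = (
        "/".join(unquote(p) for p in remaining[1:]) if len(remaining) > 1 else None
    )
    return site_url, drive_name, folder_path

print(split_site_url(
    "https://example.sharepoint.com/sites/sharepoint-tests/Shared%20Documents/test/nested%20with%20spaces"
))
# ('https://example.sharepoint.com/sites/sharepoint-tests', 'Shared Documents', 'test/nested with spaces')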

View File

@@ -104,11 +104,8 @@ def make_slack_api_rate_limited(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction", "internal_error"]:
# Log internal_error and return the response instead of failing
logger.warning(
f"Slack call encountered '{error}', skipping and continuing..."
)
elif error in ["already_reacted", "no_reaction"]:
# The response isn't used for reactions, this is basically just a pass
return e.response
else:
# Raise the error for non-transient errors
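
A minimal sketch of the retry-or-skip behaviour above, using a stand-in error object rather than the real Slack SDK exception:

import time

# Hypothetical stand-in for a Slack API error; real code inspects SlackApiError and its response headers.
class FakeSlackError(Exception):
    def __init__(self, error: str, retry_after: int = 1) -> None:
        self.error = error
        self.retry_after = retry_after
        self.response = {"ok": False, "error": error}

def call_with_rate_limit_handling(func, max_tries: int = 3):
    for _ in range(max_tries):
        try:
            return func()
        except FakeSlackError as e:
            if e.error == "ratelimited":
                time.sleep(e.retry_after)                       # transient: wait, then retry
            elif e.error in ("already_reacted", "no_reaction", "internal_error"):
                return e.response                               # non-fatal: skip and continue
            else:
                raise                                           # everything else propagates
    raise RuntimeError("rate limited too many times")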

View File

@@ -180,28 +180,23 @@ class TeamsConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_team_list: list[str] = teams
self.msal_app: msal.ConfidentialClientApplication | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
teams_client_id = credentials["teams_client_id"]
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token

View File

@@ -67,7 +67,10 @@ class SearchPipeline:
self.rerank_metrics_callback = rerank_metrics_callback
self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(self.search_settings, None)
self.document_index = get_default_document_index(
primary_index_name=self.search_settings.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config
# Preprocessing steps generate this

View File

@@ -28,9 +28,6 @@ class SyncType(str, PyEnum):
DOCUMENT_SET = "document_set"
USER_GROUP = "user_group"
CONNECTOR_DELETION = "connector_deletion"
PRUNING = "pruning" # not really a sync, but close enough
EXTERNAL_PERMISSIONS = "external_permissions"
EXTERNAL_GROUP = "external_group"
def __str__(self) -> str:
return self.value

View File

@@ -432,7 +432,7 @@ def get_paginated_index_attempts_for_cc_pair_id(
stmt = stmt.order_by(IndexAttempt.time_started.desc())
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
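
The two offset expressions above differ only in whether page numbers are 0- or 1-indexed; a tiny worked example of the arithmetic:

page_size = 10

# 0-indexed pages: page 0 starts at row 0, page 2 starts at row 20.
print([page * page_size for page in range(3)])            # [0, 10, 20]

# 1-indexed pages: page 1 starts at row 0, page 3 starts at row 20.
print([(page - 1) * page_size for page in range(1, 4)])   # [0, 10, 20]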

View File

@@ -193,13 +193,13 @@ def fetch_input_prompts_by_user(
"""
Returns all prompts belonging to the user or public prompts,
excluding those the user has specifically disabled.
Also, if `user_id` is None and AUTH_TYPE is DISABLED, then all prompts are returned.
"""
# Start with a basic query for InputPrompt
query = select(InputPrompt)
# If we have a user, left join to InputPrompt__User so we can check "disabled"
if user_id is not None:
# If we have a user, left join to InputPrompt__User to check "disabled"
IPU = aliased(InputPrompt__User)
query = query.join(
IPU,
@@ -208,30 +208,25 @@ def fetch_input_prompts_by_user(
)
# Exclude disabled prompts
# i.e. keep only those where (IPU.disabled is NULL or False)
query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False)))
if include_public:
# Return both user-owned and public prompts
# user-owned or public
query = query.where(
or_(
InputPrompt.user_id == user_id,
InputPrompt.is_public,
)
(InputPrompt.user_id == user_id) | (InputPrompt.is_public)
)
else:
# Return only user-owned prompts
# only user-owned prompts
query = query.where(InputPrompt.user_id == user_id)
else:
# user_id is None
if AUTH_TYPE == AuthType.DISABLED:
# If auth is disabled, return all prompts
query = query.where(True) # type: ignore
elif include_public:
# Anonymous usage
query = query.where(InputPrompt.is_public)
# If no user is logged in, get all prompts (public and private)
if user_id is None and AUTH_TYPE == AuthType.DISABLED:
query = query.where(True) # type: ignore
# Default to returning all prompts
# If no user is logged in but we want to include public prompts
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)

View File

@@ -3,8 +3,6 @@ from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet
from onyx.db.models import LLMProvider as LLMProviderModel
@@ -126,29 +124,10 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
def fetch_existing_llm_providers(
db_session: Session,
) -> list[LLMProviderModel]:
stmt = select(LLMProviderModel)
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
if AUTH_TYPE != AuthType.DISABLED:
# User is anonymous
return list(
db_session.scalars(
select(LLMProviderModel).where(
LLMProviderModel.is_public == True # noqa: E712
)
).all()
)
else:
# If auth is disabled, user has access to all providers
return fetch_existing_llm_providers(db_session)
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_select = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id

View File

@@ -161,7 +161,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
pinned_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -745,34 +747,6 @@ class SearchSettings(Base):
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
@property
def large_chunks_enabled(self) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return SearchSettings.can_use_large_chunks(
self.multipass_indexing, self.model_name, self.provider_type
)
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return (
multipass
and model_name.startswith("nomic-ai")
and provider_type != EmbeddingProvider.COHERE
)
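
A quick standalone illustration of the gating rule above (plain strings stand in for the EmbeddingProvider enum; model names are illustrative):

# Re-statement of the rule, not the ORM method itself.
def can_use_large_chunks(multipass: bool, model_name: str, provider_type: str | None) -> bool:
    # Large chunks need multipass indexing, a Nomic local model, and a non-Cohere provider.
    return multipass and model_name.startswith("nomic-ai") and provider_type != "cohere"

print(can_use_large_chunks(True, "nomic-ai/nomic-embed-text-v1", None))    # True
print(can_use_large_chunks(True, "text-embedding-3-small", "openai"))      # False (not a Nomic model)
print(can_use_large_chunks(False, "nomic-ai/nomic-embed-text-v1", None))   # False (multipass disabled)
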
class IndexAttempt(Base):
"""
@@ -1456,8 +1430,6 @@ class Tool(Base):
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# whether to pass through the user's OAuth token as Authorization header
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table

View File

@@ -29,21 +29,9 @@ from onyx.utils.logger import setup_logger
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
class ActiveSearchSettings:
primary: SearchSettings
secondary: SearchSettings | None
def __init__(
self, primary: SearchSettings, secondary: SearchSettings | None
) -> None:
self.primary = primary
self.secondary = secondary
def create_search_settings(
search_settings: SavedSearchSettings,
db_session: Session,
@@ -155,27 +143,21 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> ActiveSearchSettings:
"""Returns active search settings. Secondary search settings may be None."""
# Get the primary and secondary search settings
primary_search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
return ActiveSearchSettings(
primary=primary_search_settings, secondary=secondary_search_settings
)
def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings as a list. Primary settings are the first element,
and if secondary search settings exist, they will be the second element."""
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
active_search_settings = get_active_search_settings(db_session)
search_settings_list.append(active_search_settings.primary)
if active_search_settings.secondary:
search_settings_list.append(active_search_settings.secondary)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
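
A short illustration of how the primary-first ordering above is typically consumed, e.g. fanning a write out to every active index (names are stand-ins, not real settings):

active_index_names = ["danswer_chunk_current"]         # stand-in for the primary settings
secondary_index_name = "danswer_chunk_new_embedding"   # stand-in; None when no migration is in flight
if secondary_index_name:
    active_index_names.append(secondary_index_name)

def index_document(index_name: str, doc_id: str) -> None:
    print(f"indexing {doc_id} into {index_name}")

# The currently-serving index always comes first; a migration target, if any, also receives writes.
for index_name in active_index_names:
    index_document(index_name, "doc-123")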

View File

@@ -15,7 +15,6 @@ from onyx.db.models import User
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.tools.built_in_tools import get_search_tool
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -48,10 +47,6 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""
search_tool = get_search_tool(db_session)
if search_tool is None:
raise ValueError("Search tool not found")
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)
default_prompt = get_default_prompt(db_session)
@@ -65,7 +60,6 @@ def create_slack_channel_persona(
llm_filter_extraction=enable_auto_filters,
recency_bias=RecencyBiasSetting.AUTO,
prompt_ids=[default_prompt.id],
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,

View File

@@ -8,64 +8,20 @@ from sqlalchemy.orm import Session
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
from onyx.setup import setup_logger
logger = setup_logger()
def insert_sync_record(
db_session: Session,
entity_id: int,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Insert a new sync record into the database, cancelling any existing in-progress records.
"""Insert a new sync record into the database.
Args:
db_session: The database session to use
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
sync_type: The type of sync operation
"""
# If an existing in-progress sync record exists, mark as cancelled
existing_in_progress_sync_record = fetch_latest_sync_record(
db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
)
if existing_in_progress_sync_record is not None:
logger.info(
f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
f"for entity_id={entity_id} sync_type={sync_type}"
)
mark_sync_records_as_cancelled(db_session, entity_id, sync_type)
return _create_sync_record(db_session, entity_id, sync_type)
def mark_sync_records_as_cancelled(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> None:
stmt = (
update(SyncRecord)
.where(
and_(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
)
)
.values(sync_status=SyncStatus.CANCELED)
)
db_session.execute(stmt)
db_session.commit()
def _create_sync_record(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Create and insert a new sync record into the database."""
sync_record = SyncRecord(
entity_id=entity_id,
sync_type=sync_type,
@@ -83,7 +39,6 @@ def fetch_latest_sync_record(
db_session: Session,
entity_id: int,
sync_type: SyncType,
sync_status: SyncStatus | None = None,
) -> SyncRecord | None:
"""Fetch the most recent sync record for a given entity ID and status.
@@ -104,9 +59,6 @@ def fetch_latest_sync_record(
.limit(1)
)
if sync_status is not None:
stmt = stmt.where(SyncRecord.sync_status == sync_status)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
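
A standalone sketch (no real DB session) of the cancel-then-create ordering described above, so that at most one IN_PROGRESS record exists per entity and sync type:

records = [{"entity_id": 42, "sync_type": "document_set", "status": "IN_PROGRESS"}]

def insert_sync_record(entity_id: int, sync_type: str) -> dict:
    for rec in records:
        if (rec["entity_id"], rec["sync_type"], rec["status"]) == (entity_id, sync_type, "IN_PROGRESS"):
            rec["status"] = "CANCELED"   # stale in-progress record is cancelled first
    new_rec = {"entity_id": entity_id, "sync_type": sync_type, "status": "IN_PROGRESS"}
    records.append(new_rec)
    return new_rec

insert_sync_record(42, "document_set")
print([r["status"] for r in records])    # ['CANCELED', 'IN_PROGRESS']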

View File

@@ -38,7 +38,6 @@ def create_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
) -> Tool:
new_tool = Tool(
name=name,
@@ -49,7 +48,6 @@ def create_tool(
if custom_headers
else [],
user_id=user_id,
passthrough_auth=passthrough_auth,
)
db_session.add(new_tool)
db_session.commit()
@@ -64,7 +62,6 @@ def update_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool | None,
) -> Tool:
tool = get_tool_by_id(tool_id, db_session)
if tool is None:
@@ -82,8 +79,6 @@ def update_tool(
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
db_session.commit()
return tool

View File

@@ -4,63 +4,24 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import MultipassConfig
from shared_configs.configs import MULTI_TENANT
DEFAULT_BATCH_SIZE = 30
DEFAULT_INDEX_NAME = "danswer_chunk"
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
"""
Determines whether multipass should be used based on the search settings
or the default config if settings are unavailable.
"""
if search_settings is not None:
return search_settings.multipass_indexing
return ENABLE_MULTIPASS_INDEXING
def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
"""
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type
)
return MultipassConfig(
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
)
def get_both_index_properties(
db_session: Session,
) -> tuple[str, str | None, bool, bool | None]:
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
search_settings = get_current_search_settings(db_session)
config_1 = get_multipass_config(search_settings)
search_settings_new = get_secondary_search_settings(db_session)
if not search_settings_new:
return search_settings.index_name, None, config_1.enable_large_chunks, None
return search_settings.index_name, None
config_2 = get_multipass_config(search_settings)
return (
search_settings.index_name,
search_settings_new.index_name,
config_1.enable_large_chunks,
config_2.enable_large_chunks,
)
return search_settings.index_name, search_settings_new.index_name
def translate_boost_count_to_multiplier(boost: int) -> float:

View File

@@ -1,7 +1,5 @@
import httpx
from sqlalchemy.orm import Session
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
@@ -9,28 +7,17 @@ from shared_configs.configs import MULTI_TENANT
def get_default_document_index(
search_settings: SearchSettings,
secondary_search_settings: SearchSettings | None,
httpx_client: httpx.Client | None = None,
primary_index_name: str,
secondary_index_name: str | None,
) -> DocumentIndex:
"""Primary index is the index that is used for querying/updating etc.
Secondary index is for when both the currently used index and the upcoming
index both need to be updated, updates are applied to both indices"""
secondary_index_name: str | None = None
secondary_large_chunks_enabled: bool | None = None
if secondary_search_settings:
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
@@ -40,6 +27,6 @@ def get_current_primary_default_document_index(db_session: Session) -> DocumentI
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
search_settings,
None,
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)

View File

@@ -231,22 +231,21 @@ def _get_chunks_via_visit_api(
return document_chunks
# TODO(rkuo): candidate for removal if not being used
# @retry(tries=10, delay=1, backoff=2)
# def get_all_vespa_ids_for_document_id(
# document_id: str,
# index_name: str,
# filters: IndexFilters | None = None,
# get_large_chunks: bool = False,
# ) -> list[str]:
# document_chunks = _get_chunks_via_visit_api(
# chunk_request=VespaChunkRequest(document_id=document_id),
# index_name=index_name,
# filters=filters or IndexFilters(access_control_list=None),
# field_names=[DOCUMENT_ID],
# get_large_chunks=get_large_chunks,
# )
# return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
@retry(tries=10, delay=1, backoff=2)
def get_all_vespa_ids_for_document_id(
document_id: str,
index_name: str,
filters: IndexFilters | None = None,
get_large_chunks: bool = False,
) -> list[str]:
document_chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
field_names=[DOCUMENT_ID],
get_large_chunks=get_large_chunks,
)
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
def parallel_visit_api_retrieval(

View File

@@ -25,6 +25,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.engine import get_session_with_tenant
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
@@ -40,12 +41,12 @@ from onyx.document_index.vespa.chunk_retrieval import (
)
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import (
get_multipass_config,
)
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
@@ -131,34 +132,12 @@ class VespaIndex(DocumentIndex):
self,
index_name: str,
secondary_index_name: str | None,
large_chunks_enabled: bool,
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.large_chunks_enabled = large_chunks_enabled
self.secondary_large_chunks_enabled = secondary_large_chunks_enabled
self.multitenant = multitenant
self.httpx_client_context: BaseHTTPXClientContext
if httpx_client:
self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
else:
self.httpx_client_context = TemporaryHTTPXClientContext(
get_vespa_http_client
)
self.index_to_large_chunks_enabled: dict[str, bool] = {}
self.index_to_large_chunks_enabled[index_name] = large_chunks_enabled
if secondary_index_name and secondary_large_chunks_enabled:
self.index_to_large_chunks_enabled[
secondary_index_name
] = secondary_large_chunks_enabled
self.http_client = get_vespa_http_client()
def ensure_indices_exist(
self,
@@ -352,7 +331,7 @@ class VespaIndex(DocumentIndex):
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
self.httpx_client_context as http_client,
get_vespa_http_client() as http_client,
):
# We require the start and end index for each document in order to
# know precisely which chunks to delete. This information exists for
@@ -411,11 +390,9 @@ class VespaIndex(DocumentIndex):
for doc_id in all_doc_ids
}
@classmethod
@staticmethod
def _apply_updates_batched(
cls,
updates: list[_VespaUpdateRequest],
httpx_client: httpx.Client,
batch_size: int = BATCH_SIZE,
) -> None:
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
@@ -437,7 +414,7 @@ class VespaIndex(DocumentIndex):
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx_client as http_client,
get_vespa_http_client() as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
@@ -478,7 +455,7 @@ class VespaIndex(DocumentIndex):
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
with self.httpx_client_context as http_client:
with get_vespa_http_client() as http_client:
for update_request in update_requests:
for doc_info in update_request.minimal_document_indexing_info:
for index_name in index_names:
@@ -534,8 +511,7 @@ class VespaIndex(DocumentIndex):
)
)
with self.httpx_client_context as httpx_client:
self._apply_updates_batched(processed_updates_requests, httpx_client)
self._apply_updates_batched(processed_updates_requests)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
@@ -547,7 +523,6 @@ class VespaIndex(DocumentIndex):
index_name: str,
fields: VespaDocumentFields,
doc_id: str,
http_client: httpx.Client,
) -> None:
"""
Update a single "chunk" (document) in Vespa using its chunk ID.
@@ -579,17 +554,18 @@ class VespaIndex(DocumentIndex):
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
try:
resp = http_client.put(
vespa_url,
headers={"Content-Type": "application/json"},
json=update_dict,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
raise
with get_vespa_http_client(http2=False) as http_client:
try:
resp = http_client.put(
vespa_url,
headers={"Content-Type": "application/json"},
json=update_dict,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
raise
def update_single(
self,
@@ -603,16 +579,24 @@ class VespaIndex(DocumentIndex):
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
doc_chunk_count = 0
with self.httpx_client_context as httpx_client:
for (
index_name,
large_chunks_enabled,
) in self.index_to_large_chunks_enabled.items():
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
multipass_config = get_multipass_config(
db_session=db_session,
primary_index=index_name == self.index_name,
)
large_chunks_enabled = multipass_config.enable_large_chunks
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=httpx_client,
http_client=http_client,
document_id=doc_id,
previous_chunk_count=chunk_count,
new_chunk_count=0,
@@ -628,7 +612,10 @@ class VespaIndex(DocumentIndex):
for doc_chunk_id in doc_chunk_ids:
self.update_single_chunk(
doc_chunk_id, index_name, fields, doc_id, httpx_client
doc_chunk_id=doc_chunk_id,
index_name=index_name,
fields=fields,
doc_id=doc_id,
)
return doc_chunk_count
@@ -650,13 +637,19 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with self.httpx_client_context as http_client, concurrent.futures.ThreadPoolExecutor(
with get_vespa_http_client(
http2=False
) as http_client, concurrent.futures.ThreadPoolExecutor(
max_workers=NUM_THREADS
) as executor:
for (
index_name,
large_chunks_enabled,
) in self.index_to_large_chunks_enabled.items():
for index_name in index_names:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
multipass_config = get_multipass_config(
db_session=db_session,
primary_index=index_name == self.index_name,
)
large_chunks_enabled = multipass_config.enable_large_chunks
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=http_client,
@@ -825,9 +818,6 @@ class VespaIndex(DocumentIndex):
"""
Deletes all entries in the specified index with the given tenant_id.
Currently unused, but we anticipate this being useful. The entire flow does not
use the httpx connection pool of an instance.
Parameters:
tenant_id (str): The tenant ID whose documents are to be deleted.
index_name (str): The name of the index from which to delete documents.
@@ -860,8 +850,6 @@ class VespaIndex(DocumentIndex):
"""
Retrieves all document IDs with the specified tenant_id, handling pagination.
Internal helper function for delete_entries_by_tenant_id.
Parameters:
tenant_id (str): The tenant ID to search for.
index_name (str): The name of the index to search in.
@@ -894,8 +882,8 @@ class VespaIndex(DocumentIndex):
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
)
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=query_params, timeout=None)
with get_vespa_http_client(no_timeout=True) as http_client:
response = http_client.get(url, params=query_params)
response.raise_for_status()
search_result = response.json()
@@ -925,11 +913,6 @@ class VespaIndex(DocumentIndex):
"""
Deletes documents in batches using multiple threads.
Internal helper function for delete_entries_by_tenant_id.
This is a class method and does not use the httpx pool of the instance.
This is OK because we don't use this method often.
Parameters:
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
batch_size (int): The number of documents to delete in each batch.
@@ -942,14 +925,13 @@ class VespaIndex(DocumentIndex):
response = http_client.delete(
delete_request.url,
headers={"Content-Type": "application/json"},
timeout=None,
)
response.raise_for_status()
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
with get_vespa_http_client() as http_client:
with get_vespa_http_client(no_timeout=True) as http_client:
for batch_start in range(0, len(delete_requests), batch_size):
batch = delete_requests[batch_start : batch_start + batch_size]

View File

@@ -1,19 +1,21 @@
import concurrent.futures
import json
import uuid
from abc import ABC
from abc import abstractmethod
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
import httpx
from retry import retry
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.document_index_utils import get_uuid_from_chunk
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
@@ -48,9 +50,10 @@ from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import EmbeddingProvider
from onyx.indexing.models import MultipassConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -272,42 +275,46 @@ def check_for_final_chunk_existence(
index += 1
class BaseHTTPXClientContext(ABC):
"""Abstract base class for an HTTPX client context manager."""
@abstractmethod
def __enter__(self) -> httpx.Client:
pass
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
pass
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
"""
Determines whether multipass should be used based on the search settings
or the default config if settings are unavailable.
"""
if search_settings is not None:
return search_settings.multipass_indexing
return ENABLE_MULTIPASS_INDEXING
class GlobalHTTPXClientContext(BaseHTTPXClientContext):
"""Context manager for a global HTTPX client that does not close it."""
def __init__(self, client: httpx.Client):
self._client = client
def __enter__(self) -> httpx.Client:
return self._client # Reuse the global client
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
pass # Do nothing; don't close the global client
def can_use_large_chunks(multipass: bool, search_settings: SearchSettings) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return (
multipass
and search_settings.model_name.startswith("nomic-ai")
and search_settings.provider_type != EmbeddingProvider.COHERE
)
class TemporaryHTTPXClientContext(BaseHTTPXClientContext):
"""Context manager for a temporary HTTPX client that closes it after use."""
def __init__(self, client_factory: Callable[[], httpx.Client]):
self._client_factory = client_factory
self._client: httpx.Client | None = None # Client will be created in __enter__
def __enter__(self) -> httpx.Client:
self._client = self._client_factory() # Create a new client
return self._client
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
if self._client:
self._client.close()
def get_multipass_config(
db_session: Session, primary_index: bool = True
) -> MultipassConfig:
"""
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
search_settings = (
get_current_search_settings(db_session)
if primary_index
else get_secondary_search_settings(db_session)
)
multipass = should_use_multipass(search_settings)
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
enable_large_chunks = can_use_large_chunks(multipass, search_settings)
return MultipassConfig(
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
)
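
A hedged usage sketch of the two client contexts defined above: the global variant reuses an injected httpx.Client and leaves it open, while the temporary variant builds one from a factory and closes it on exit (the request URL is hypothetical; imports assume the onyx package is installed):

import httpx

from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext

shared_client = httpx.Client(http2=False)

with GlobalHTTPXClientContext(shared_client) as client:
    client.get("https://example.com")        # hypothetical request
assert not shared_client.is_closed           # the shared client survives the context

with TemporaryHTTPXClientContext(lambda: httpx.Client(http2=False)) as client:
    client.get("https://example.com")        # this client is closed when the block exits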

View File

@@ -1,5 +1,4 @@
import re
import time
from typing import cast
import httpx
@@ -8,10 +7,6 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
@@ -55,7 +50,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
"""Vespa does not take in unicode chars that aren't valid for XML.
This removes them."""
_illegal_xml_chars_RE: re.Pattern = re.compile(
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFDD0-\uFDEF\uFFFE\uFFFF]"
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)
return _illegal_xml_chars_RE.sub("", text)
@@ -74,37 +69,3 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=http2,
)
def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
"""Waits for Vespa to become ready subject to a timeout.
Returns True if Vespa is ready, False otherwise."""
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_APP_CONTAINER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return True
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > wait_limit:
logger.info(
f"Vespa: Readiness probe did not succeed within the timeout "
f"({wait_limit} seconds)."
)
return False
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit:.1f}"
)
time.sleep(wait_interval)

View File

@@ -358,13 +358,7 @@ def extract_file_text(
try:
if get_unstructured_api_key():
try:
return unstructured_to_text(file, file_name)
except Exception as unstructured_error:
logger.error(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
# Fall through to normal processing
return unstructured_to_text(file, file_name)
if file_name or extension:
if extension is not None:

View File

@@ -52,7 +52,7 @@ def _sdk_partition_request(
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="fast")
req = _sdk_partition_request(file, file_name, strategy="auto")
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())

View File

@@ -1,57 +0,0 @@
import threading
from typing import Any
import httpx
class HttpxPool:
"""Class to manage a global httpx Client instance"""
_clients: dict[str, httpx.Client] = {}
_lock: threading.Lock = threading.Lock()
# Default parameters for creation
DEFAULT_KWARGS = {
"http2": True,
"limits": lambda: httpx.Limits(),
}
def __init__(self) -> None:
pass
@classmethod
def _init_client(cls, **kwargs: Any) -> httpx.Client:
"""Private helper method to create and return an httpx.Client."""
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
return httpx.Client(**merged_kwargs)
@classmethod
def init_client(cls, name: str, **kwargs: Any) -> None:
"""Allow the caller to init the client with extra params."""
with cls._lock:
if name not in cls._clients:
cls._clients[name] = cls._init_client(**kwargs)
@classmethod
def close_client(cls, name: str) -> None:
"""Allow the caller to close the client."""
with cls._lock:
client = cls._clients.pop(name, None)
if client:
client.close()
@classmethod
def close_all(cls) -> None:
"""Close all registered clients."""
with cls._lock:
for client in cls._clients.values():
client.close()
cls._clients.clear()
@classmethod
def get(cls, name: str) -> httpx.Client:
"""Gets the httpx.Client. Will init to default settings if not init'd."""
with cls._lock:
if name not in cls._clients:
cls._clients[name] = cls._init_client()
return cls._clients[name]

View File

@@ -31,15 +31,14 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.document_index.document_index_utils import (
get_multipass_config,
)
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.indexing_utils import (
get_multipass_config,
)
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -358,6 +357,7 @@ def index_doc_batch(
is_public=False,
)
logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)
ctx = index_doc_batch_prepare(
@@ -527,8 +527,7 @@ def build_indexing_pipeline(
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)
multipass_config = get_multipass_config(db_session, primary_index=True)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,

View File

@@ -55,7 +55,9 @@ class DocAwareChunk(BaseChunk):
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
return (
f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
)
class IndexChunk(DocAwareChunk):

View File

@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
def get_kv_store() -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore(tenant_id=tenant_id)
return PgRedisKVStore()
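A self-contained sketch of the "tenant is read from a context variable when not passed explicitly" behaviour described in the comment above; the names here are illustrative stand-ins, not the repo's.

from contextvars import ContextVar

CURRENT_TENANT_ID: ContextVar[str] = ContextVar("current_tenant_id", default="public")


def resolve_tenant(tenant_id: str | None = None) -> str:
    # Explicit argument wins; otherwise fall back to the ambient context value.
    return tenant_id or CURRENT_TENANT_ID.get()


token = CURRENT_TENANT_ID.set("tenant_abc")
try:
    assert resolve_tenant() == "tenant_abc"              # picked up from the context
    assert resolve_tenant("tenant_xyz") == "tenant_xyz"  # explicit override wins
finally:
    CURRENT_TENANT_ID.reset(token)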

View File

@@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def _get_session(self) -> Iterator[Session]:
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(self.tenant_id):
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
@@ -66,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
@@ -88,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
@@ -113,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
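A minimal sketch of the per-tenant search_path pattern in the hunk above, assuming a Postgres engine; the schema-name check is a simplified stand-in for the repo's is_valid_schema_name, and the connection URL is a placeholder.

import re
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import text
from sqlalchemy.orm import Session

_SCHEMA_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")  # simplified validity check


@contextmanager
def tenant_session(engine, tenant_id: str) -> Iterator[Session]:
    if not _SCHEMA_RE.match(tenant_id):
        raise ValueError(f"Invalid tenant ID: {tenant_id!r}")
    with Session(engine, expire_on_commit=False) as session:
        # Scope every query in this session to the tenant's schema.
        session.execute(text(f'SET search_path = "{tenant_id}"'))
        yield session


# from sqlalchemy import create_engine
# engine = create_engine("postgresql+psycopg2://user:pass@localhost/onyx")
# with tenant_session(engine, "tenant_abc") as session:
#     session.execute(text("SELECT 1"))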

View File

@@ -275,22 +275,17 @@ class DefaultMultiLLM(LLM):
# additional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
# Specifically pass in "vertex_credentials" / "vertex_location" as a
# model_kwarg to the completion call for vertex AI. More details here:
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_location_key = "vertex_location"
for k, v in custom_config.items():
if model_provider == "vertex_ai":
if k == vertex_credentials_key:
model_kwargs[k] = v
continue
elif k == vertex_location_key:
model_kwargs[k] = v
continue
# for all values, set them as env variables
os.environ[k] = v
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
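A sketch of the config-splitting behaviour in the hunk above: for Vertex AI a couple of keys are forwarded as completion kwargs, while everything else is exported as environment variables. The key names are taken from the diff; the helper itself is illustrative.

import os
from typing import Any

VERTEX_PASSTHROUGH_KEYS = {"vertex_credentials", "vertex_location"}


def split_custom_config(
    custom_config: dict[str, str], model_provider: str
) -> dict[str, Any]:
    model_kwargs: dict[str, Any] = {}
    for key, value in custom_config.items():
        if model_provider == "vertex_ai" and key in VERTEX_PASSTHROUGH_KEYS:
            model_kwargs[key] = value  # forwarded to the completion call
            continue
        os.environ[key] = value  # all other values become env variables
    return model_kwargs


kwargs = split_custom_config(
    {"vertex_location": "us-central1", "SOME_API_BASE": "http://localhost"},
    model_provider="vertex_ai",
)
assert kwargs == {"vertex_location": "us-central1"}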

View File

@@ -142,7 +142,6 @@ def build_content_with_imgs(
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
@@ -158,7 +157,7 @@ def build_content_with_imgs(
message_main_content = _build_content(message, files)
if exclude_images or (not img_files and not img_urls):
if not img_files and not img_urls:
return message_main_content
return cast(
@@ -383,19 +382,9 @@ def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
stripped_model_name = _strip_extra_provider_from_model_name(model_name)
model_names = [
model_name,
_strip_extra_provider_from_model_name(model_name),
# Remove leading extra provider. Usually for cases where user has a
# custom model proxy which appends another prefix
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(stripped_model_name),
]
def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]
@@ -428,10 +417,21 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# custom model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
@@ -523,23 +523,3 @@ def get_max_input_tokens(
raise RuntimeError("No tokens for input for the LLM given settings")
return input_toks
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
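A small, self-contained sketch of the model-name normalisation tried in the lookup above; _strip_colon_from_model_name matches the definition shown in this hunk, while the provider-prefix stripping is an assumption about the other helper's behaviour.

def strip_extra_provider(model_name: str) -> str:
    # e.g. "openai/gpt-4o" -> "gpt-4o" (assumed behaviour of the real helper)
    return model_name.split("/", 1)[1] if "/" in model_name else model_name


def strip_colon(model_name: str) -> str:
    # e.g. "llama3:8b" -> "llama3" (needed for ollama-style tags)
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def candidate_names(model_name: str) -> list[str]:
    stripped = strip_extra_provider(model_name)
    names = [model_name, stripped, strip_colon(model_name), strip_colon(stripped)]
    # Deduplicate while preserving order, dropping empty entries.
    return list(dict.fromkeys(n for n in names if n))


assert candidate_names("ollama/llama3:8b") == [
    "ollama/llama3:8b", "llama3:8b", "ollama/llama3", "llama3",
]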

View File

@@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid(tenant_id=None)
get_or_generate_uuid()
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:

View File

@@ -14,7 +14,6 @@ from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@@ -123,9 +122,6 @@ class SlackbotHandler:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@@ -163,15 +159,10 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@@ -180,9 +171,7 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -190,21 +179,17 @@ class SlackbotHandler:
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are missing or empty, close the socket client and remove them.
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
if not slack_bot_tokens:
logger.debug(
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
@@ -219,10 +204,9 @@ class SlackbotHandler:
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
)
else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
@@ -233,168 +217,77 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None:
"""
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
tenant_ids = get_all_tenant_ids()
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
pod_id = self.pod_id
acquired = redis_client.set(
OnyxRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
lock_acquired = rlock.acquire(blocking=False)
if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)
logger.debug(f"Acquired lock for tenant {tenant_id}")
# Now check if this tenant actually has Slack bots
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)
if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
bots = fetch_slack_bots(db_session=db_session)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
else:
# If no Slack bots, release lock immediately (unless in DEV_MODE)
if lock_acquired and not DEV_MODE:
rlock.release()
del self.redis_locks[tenant_id]
logger.debug(
f"No Slack bots for tenant {tenant_id}; lock released (if held)."
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass (and remove below)
bots = []
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []
if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)
# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _remove_tenant(self, tenant_id: str | None) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)
# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@@ -422,7 +315,6 @@ class SlackbotHandler:
)
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -430,7 +322,7 @@ class SlackbotHandler:
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -448,19 +340,17 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants we currently hold
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in list(self.tenant_ids):
if tenant_id in self.redis_locks:
try:
self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with a timeout)
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
@@ -537,36 +427,30 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
# Let the tag flow handle this case, don't reply twice
return False
# Check if this is a bot message (either via bot_profile or bot_message subtype)
is_bot_message = bool(
event.get("bot_profile") or event.get("subtype") == "bot_message"
)
if is_bot_message:
if event.get("bot_profile"):
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info(
"Ignoring message from bot since respond_to_bots is disabled"
)
channel_specific_logger.info("Ignoring message from bot")
return False
# Ignore things like channel_join, channel_leave, etc.
# NOTE: "file_share" is just a message with a file attachment, so we
# should not ignore it
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share", "bot_message"]:
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)
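A minimal sketch of the non-blocking per-tenant Redis lock acquisition appearing in the acquire_tenants hunk above; the key name, TTL value, and connection details are placeholders rather than the repo's constants.

import redis
from redis.lock import Lock

TENANT_LOCK_EXPIRATION = 1800  # seconds, illustrative


def try_acquire_tenant_lock(client: redis.Redis, tenant_id: str) -> Lock | None:
    rlock = client.lock(f"slack_bot_lock:{tenant_id}", timeout=TENANT_LOCK_EXPIRATION)
    # Non-blocking: if another pod already holds the lock, skip this tenant.
    if rlock.acquire(blocking=False):
        return rlock
    return None


# client = redis.Redis(host="localhost", port=6379)
# lock = try_acquire_tenant_lock(client, "tenant_abc")
# if lock is not None:
#     try:
#         ...  # this pod owns the tenant's Slack bots
#     finally:
#         lock.release()  # release in the same scope it was acquired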

View File

@@ -19,8 +19,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_DANSWER_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
_BASIC_TIME_STR = "The current date is {datetime_info}."
MOST_BASIC_PROMPT = "You are a helpful AI assistant."
DANSWER_DATETIME_REPLACEMENT = "DANSWER_DATETIME_REPLACEMENT"
BASIC_TIME_STR = "The current date is {datetime_info}."
def get_current_llm_day_time(
@@ -37,36 +38,23 @@ def get_current_llm_day_time(
return f"{formatted_datetime}"
def build_date_time_string() -> str:
return ADDITIONAL_INFO.format(
datetime_info=_BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)
def handle_onyx_date_awareness(
prompt_str: str,
prompt_config: PromptConfig,
add_additional_info_if_no_tag: bool = False,
) -> str:
"""
If there is a [[CURRENT_DATETIME]] tag, replace it with the current date and time no matter what.
If the prompt is datetime aware, and there are no [[CURRENT_DATETIME]] tags, add it to the prompt.
Do nothing otherwise.
This can later be expanded to support other tags.
"""
if _DANSWER_DATETIME_REPLACEMENT_PAT in prompt_str:
def add_date_time_to_prompt(prompt_str: str) -> str:
if DANSWER_DATETIME_REPLACEMENT in prompt_str:
return prompt_str.replace(
_DANSWER_DATETIME_REPLACEMENT_PAT,
DANSWER_DATETIME_REPLACEMENT,
get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
)
any_tag_present = any(
_DANSWER_DATETIME_REPLACEMENT_PAT in text
for text in [prompt_str, prompt_config.system_prompt, prompt_config.task_prompt]
)
if add_additional_info_if_no_tag and not any_tag_present:
return prompt_str + build_date_time_string()
return prompt_str
if prompt_str:
return prompt_str + ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
return (
MOST_BASIC_PROMPT
+ " "
+ BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)
def build_task_prompt_reminders(
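A small runnable sketch of the [[CURRENT_DATETIME]] handling described in the docstring above; the tag literal matches the diff, while the fallback sentence and date format are illustrative.

from datetime import datetime

DATETIME_TAG = "[[CURRENT_DATETIME]]"


def apply_datetime_tag(prompt: str, datetime_aware: bool) -> str:
    now = datetime.now().strftime("%B %d, %Y %H:%M")
    if DATETIME_TAG in prompt:
        # The tag always wins: substitute in place, wherever it appears.
        return prompt.replace(DATETIME_TAG, now)
    if datetime_aware:
        # No tag present, but the prompt is datetime-aware: append the info.
        return prompt + f"\n\nThe current date is {now}."
    return prompt


print(apply_datetime_tag("You are helpful. The current date is [[CURRENT_DATETIME]].", True))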

View File

@@ -30,17 +30,10 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# these gaps are difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300
def __init__(
self,
@@ -66,7 +59,6 @@ class RedisConnectorIndex:
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -118,24 +110,7 @@ class RedisConnectorIndex:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(
f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
)
def set_watchdog(self, value: bool) -> None:
"""Signal the state of the watchdog."""
if not value:
self.redis.delete(self.watchdog_key)
return
self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -143,7 +118,7 @@ class RedisConnectorIndex:
The slack in timing is needed because simply checking the celery queue
and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
self.redis.set(self.active_key, 0, ex=3600)
def active(self) -> bool:
if self.redis.exists(self.active_key):
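A sketch of the TTL-backed signal keys used above (terminate / active / watchdog): setting a key with an expiry raises the signal, deleting it clears it, and the TTL guarantees stale signals eventually disappear. Key names, TTLs, and the Redis connection here are placeholders.

import redis

ACTIVE_TTL = 3600
WATCHDOG_TTL = 300


def set_signal(client: redis.Redis, key: str, ttl: int, value: bool = True) -> None:
    # Setting the key (with an expiry) raises the signal; deleting clears it.
    if value:
        client.set(key, 0, ex=ttl)
    else:
        client.delete(key)


def signaled(client: redis.Redis, key: str) -> bool:
    return bool(client.exists(key))


# client = redis.Redis()
# set_signal(client, "connectorindexing_watchdog_1/2", WATCHDOG_TTL)
# assert signaled(client, "connectorindexing_watchdog_1/2")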

View File

@@ -92,7 +92,7 @@ class RedisConnectorPrune:
if fence_bytes is None:
return None
fence_int = int(cast(bytes, fence_bytes))
fence_int = cast(int, fence_bytes)
return fence_int
@generator_complete.setter
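The hunk above swaps between cast(int, ...) and an explicit int(...) conversion. Worth noting that typing.cast is a no-op at runtime, so only int(...) actually converts the bytes returned by Redis; a tiny illustration:

from typing import cast

fence_bytes = b"3"                          # what a raw redis GET returns
assert cast(int, fence_bytes) == b"3"       # cast() changes nothing at runtime
assert int(cast(bytes, fence_bytes)) == 3   # int() performs the real conversion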

View File

@@ -21,7 +21,6 @@ from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PASSWORD
from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_REPLICA_HOST
from onyx.configs.app_configs import REDIS_SSL
from onyx.configs.app_configs import REDIS_SSL_CA_CERTS
from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
@@ -133,32 +132,23 @@ class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: redis.BlockingConnectionPool
_replica_pool: redis.BlockingConnectionPool
def __new__(cls) -> "RedisPool":
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(RedisPool, cls).__new__(cls)
cls._instance._init_pools()
cls._instance._init_pool()
return cls._instance
def _init_pools(self) -> None:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
def get_replica_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@@ -222,10 +212,6 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_replica_client(tenant_id)
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

View File

@@ -16,7 +16,7 @@ from onyx.context.search.preprocessing.access_filters import (
from onyx.db.document_set import get_document_sets_by_ids
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
@@ -34,11 +34,8 @@ def get_random_chunks_from_doc_sets(
"""
Retrieves random chunks from the specified document sets.
"""
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
acl_filters = build_access_filters_for_user(user, db_session)
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)

View File

@@ -6184,7 +6184,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"title_embedding": [

View File

@@ -24,7 +24,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0

View File

@@ -3,7 +3,6 @@ import json
import os
from typing import cast
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.access.models import default_public_access
@@ -24,11 +23,9 @@ from onyx.db.document import check_docs_exist
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import mock_successful_index_attempt
from onyx.db.models import Document as DbDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -36,6 +33,7 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import ConnectorBase
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@@ -61,7 +59,6 @@ def _create_indexable_chunks(
doc_updated_at=None,
primary_owners=[],
secondary_owners=[],
chunk_count=1,
)
if preprocessed_doc["chunk_ind"] == 0:
ids_to_documents[document.id] = document
@@ -158,7 +155,9 @@ def seed_initial_documents(
logger.info("Embedding model has been updated, skipping")
return
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Create a connector so the user can delete it if they want
# or reindex it with a new search model if they want
@@ -219,11 +218,9 @@ def seed_initial_documents(
# Retries here because the index may take a few seconds to become ready
# as we just sent over the Vespa schema and there is a slight delay
if not wait_for_vespa_with_timeout():
logger.error("Vespa did not become ready within the timeout")
raise ValueError("Vespa failed to become ready within the timeout")
document_index.index(
index_with_retries = retry_builder(tries=15)(document_index.index)
index_with_retries(
chunks=chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt={},
@@ -241,12 +238,4 @@ def seed_initial_documents(
db_session=db_session,
)
# Since we bypass the indexing flow, we need to manually update the chunk count
for doc in docs:
db_session.execute(
update(DbDocument)
.where(DbDocument.id == doc.id)
.values(chunk_count=doc.chunk_count)
)
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)
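A self-contained sketch of the manual chunk_count backfill performed above, using a SQLAlchemy update(); the Document model here is a simplified stand-in for the repo's model and SQLite is used in place of Postgres.

from sqlalchemy import Integer, String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):
    __tablename__ = "document"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True)


engine = create_engine("sqlite://")  # in-memory stand-in for Postgres
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Document(id="doc-1", chunk_count=None))
    session.commit()
    # Because seeding bypasses the indexing flow, set the count explicitly.
    session.execute(
        update(Document).where(Document.id == "doc-1").values(chunk_count=1)
    )
    session.commit()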

View File

@@ -8,7 +8,7 @@ prompts:
# System Prompt (as shown in UI)
system: >
You are a question answering system that is constantly learning and improving.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
@@ -24,7 +24,7 @@ prompts:
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the [[CURRENT_DATETIME]] is set, the date/time is inserted there instead
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Prompts the LLM to include citations in the form [1], [2] etc.
@@ -51,7 +51,7 @@ prompts:
- name: "OnlyLLM"
description: "Chat directly with the LLM!"
system: >
You are a helpful AI assistant. The current date is [[CURRENT_DATETIME]]
You are a helpful AI assistant. The current date is DANSWER_DATETIME_REPLACEMENT
You give concise responses to very simple questions, but provide more thorough responses to
@@ -69,7 +69,7 @@ prompts:
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -87,7 +87,7 @@ prompts:
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
Quote and cite relevant information from provided context based on the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!

View File

@@ -15,9 +15,6 @@ from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
try_creating_permissions_sync_task,
)
from onyx.background.celery.tasks.external_group_syncing.tasks import (
try_creating_external_group_sync_task,
)
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
@@ -42,7 +39,7 @@ from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
@@ -65,7 +62,7 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts")
def get_cc_pair_index_attempts(
cc_pair_id: int,
page_num: int = Query(0, ge=0),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -84,7 +81,7 @@ def get_cc_pair_index_attempts(
index_attempts = get_paginated_index_attempts_for_cc_pair_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
page=page_num,
page=page,
page_size=page_size,
)
return PaginatedReturn(
@@ -192,7 +189,7 @@ def update_cc_pair_status(
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
redis_connector.stop.set_fence(True)
search_settings_list: list[SearchSettings] = get_active_search_settings_list(
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
@@ -446,78 +443,6 @@ def sync_cc_pair(
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
def get_cc_pair_latest_group_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_external_group_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
def sync_cc_pair_groups(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers group sync on a particular cc_pair immediately"""
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.external_group_sync.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="External group sync task already in progress.",
)
logger.info(
f"External group sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
)
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
def get_docs_sync_status(
cc_pair_id: int,

View File

@@ -1,8 +1,5 @@
import mimetypes
import os
import uuid
import zipfile
from io import BytesIO
from typing import cast
from fastapi import APIRouter
@@ -389,43 +386,10 @@ def upload_files(
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
# Skip directories and known macOS metadata entries
def should_process_file(file_path: str) -> bool:
normalized_path = os.path.normpath(file_path)
return not any(part.startswith(".") for part in normalized_path.split(os.sep))
try:
file_store = get_default_file_store(db_session)
deduped_file_paths = []
for file in files:
if file.content_type and file.content_type.startswith("application/zip"):
with zipfile.ZipFile(file.file, "r") as zf:
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
sub_file_name = os.path.join(str(uuid.uuid4()), file_info)
deduped_file_paths.append(sub_file_name)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_store.save_file(
file_name=sub_file_name,
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=FileOrigin.CONNECTOR,
file_type=mime_type,
)
continue
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
deduped_file_paths.append(file_path)
file_store.save_file(
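A self-contained sketch of the zip-expansion flow in the hunk above: directories and dot-prefixed path components are skipped, and a MIME type is guessed per entry. The helper name is illustrative.

import io
import mimetypes
import os
import zipfile


def should_process(path: str) -> bool:
    normalized = os.path.normpath(path)
    return not any(part.startswith(".") for part in normalized.split(os.sep))


def iter_zip_entries(zip_bytes: bytes):
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        for name in zf.namelist():
            if zf.getinfo(name).is_dir() or not should_process(name):
                continue
            mime, _ = mimetypes.guess_type(name)
            yield name, mime or "application/octet-stream", zf.read(name)


# Build a tiny in-memory zip to demonstrate the filtering.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("docs/readme.txt", "hello")
    zf.writestr(".hidden/secret.txt", "skip me")

assert [name for name, _, _ in iter_zip_entries(buf.getvalue())] == ["docs/readme.txt"]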

View File

@@ -32,7 +32,10 @@ def get_document_info(
db_session: Session = Depends(get_session),
) -> DocumentInfo:
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
inference_chunks = document_index.id_based_retrieval(
@@ -76,7 +79,10 @@ def get_chunk_info(
db_session: Session = Depends(get_session),
) -> ChunkInfo:
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
chunk_request = VespaChunkRequest(

View File

@@ -357,7 +357,6 @@ class ConnectorCredentialPairDescriptor(BaseModel):
name: str | None = None
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
class RunConnectorRequest(BaseModel):

View File

@@ -68,7 +68,6 @@ class DocumentSet(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential
),
access_type=cc_pair.access_type,
)
for cc_pair in document_set_model.connector_credential_pairs
],

View File

@@ -7,7 +7,6 @@ from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -192,7 +191,8 @@ def create_persona(
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
datetime_aware=persona_upsert_request.datetime_aware,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
@@ -236,7 +236,8 @@ def update_persona(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
datetime_aware=persona_upsert_request.datetime_aware,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
include_citations=persona_upsert_request.include_citations,
@@ -276,14 +277,8 @@ def create_label(
_: User | None = Depends(current_user),
) -> PersonaLabelResponse:
"""Create a new assistant label"""
try:
label_model = create_assistant_label(name=label.name, db_session=db)
return PersonaLabelResponse.from_model(label_model)
except IntegrityError:
raise HTTPException(
status_code=400,
detail=f"Label with name '{label.name}' already exists. Please choose a different name.",
)
label_model = create_assistant_label(name=label.name, db_session=db)
return PersonaLabelResponse.from_model(label_model)
@admin_router.patch("/label/{label_id}")

View File

@@ -60,7 +60,6 @@ class PersonaUpsertRequest(BaseModel):
description: str
system_prompt: str
task_prompt: str
datetime_aware: bool
document_set_ids: list[int]
num_chunks: float
include_citations: bool

View File

@@ -41,16 +41,6 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
raise HTTPException(status_code=400, detail=str(e))
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
if tool_data.passthrough_auth and tool_data.custom_headers:
for header in tool_data.custom_headers:
if header.key.lower() == "authorization":
raise HTTPException(
status_code=400,
detail="Cannot use passthrough auth with custom authorization headers",
)
@admin_router.post("/custom")
def create_custom_tool(
tool_data: CustomToolCreate,
@@ -58,7 +48,6 @@ def create_custom_tool(
user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
tool = create_tool(
name=tool_data.name,
description=tool_data.description,
@@ -66,7 +55,6 @@ def create_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(tool)
@@ -80,7 +68,6 @@ def update_custom_tool(
) -> ToolSnapshot:
if tool_data.definition:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
updated_tool = update_tool(
tool_id=tool_id,
name=tool_data.name,
@@ -89,7 +76,6 @@ def update_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(updated_tool)
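A self-contained sketch of the passthrough-auth validation shown in the hunk above: a custom "Authorization" header conflicts with passthrough auth and is rejected. The Pydantic models are trimmed stand-ins for the repo's CustomToolCreate and Header, and the exception type is simplified to ValueError instead of HTTPException.

from pydantic import BaseModel


class Header(BaseModel):
    key: str
    value: str


class CustomToolCreate(BaseModel):
    name: str
    passthrough_auth: bool = False
    custom_headers: list[Header] | None = None


def validate_auth_settings(tool: CustomToolCreate) -> None:
    if tool.passthrough_auth and tool.custom_headers:
        for header in tool.custom_headers:
            if header.key.lower() == "authorization":
                raise ValueError(
                    "Cannot use passthrough auth with custom authorization headers"
                )


# A non-authorization header is fine alongside passthrough auth.
validate_auth_settings(
    CustomToolCreate(
        name="ok",
        passthrough_auth=True,
        custom_headers=[Header(key="X-Trace", value="1")],
    )
)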

View File

@@ -13,7 +13,6 @@ class ToolSnapshot(BaseModel):
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
passthrough_auth: bool
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@@ -25,7 +24,6 @@ class ToolSnapshot(BaseModel):
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
passthrough_auth=tool.passthrough_auth,
)
@@ -39,7 +37,6 @@ class CustomToolCreate(BaseModel):
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
passthrough_auth: bool
class CustomToolUpdate(BaseModel):
@@ -47,4 +44,3 @@ class CustomToolUpdate(BaseModel):
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None
passthrough_auth: bool | None = None

View File

@@ -10,7 +10,6 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -196,7 +195,5 @@ def list_llm_provider_basics(
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers_for_user(
db_session, user
)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@@ -44,6 +44,7 @@ class UserPreferences(BaseModel):
chosen_assistants: list[int] | None = None
hidden_assistants: list[int] = []
visible_assistants: list[int] = []
recent_assistants: list[int] | None = None
default_model: str | None = None
auto_scroll: bool | None = None
pinned_assistants: list[int] | None = None

View File

@@ -22,7 +22,6 @@ from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
@@ -98,9 +97,10 @@ def set_new_search_settings(
)
# Ensure Vespa has the new index immediately
get_multipass_config(search_settings)
get_multipass_config(new_search_settings)
document_index = get_default_document_index(search_settings, new_search_settings)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=new_search_settings.index_name,
)
document_index.ensure_indices_exist(
index_embedding_dim=search_settings.model_dim,

Some files were not shown because too many files have changed in this diff.