Compare commits

..

3 Commits

Author         SHA1         Message   Date
pablodanswer   dc089d6842   p         2025-01-25 18:16:21 -08:00
pablodanswer   ef1fc19de7   k         2025-01-25 18:16:01 -08:00
pablodanswer   842fcf4156   update    2025-01-25 18:15:32 -08:00
210 changed files with 2932 additions and 6546 deletions

View File

@@ -8,8 +8,6 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.17.0
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.7.0
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -39,12 +39,6 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -1,36 +0,0 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -1,37 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -32,7 +32,6 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -13,7 +13,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -258,7 +257,6 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -267,12 +265,6 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -342,7 +334,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -367,12 +359,6 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@@ -381,5 +367,4 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -14,8 +14,6 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@@ -35,17 +33,10 @@ def _build_group_member_email_map(
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -29,7 +28,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -45,12 +44,6 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -43,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -68,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -129,7 +131,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -147,12 +149,6 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -15,7 +14,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -25,14 +24,6 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -136,7 +127,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -15,13 +15,11 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]
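
For context, the DocSyncFuncType alias in the hunk above pins the call signature that every per-source doc sync function (confluence_doc_sync, gmail_doc_sync, gdrive_doc_sync, slack_doc_sync) must satisfy once the heartbeat callback parameter is dropped. Below is a minimal, hypothetical sketch of how such a typed-callable registry behaves; the class names are simplified stand-ins, not the real Onyx models.

from collections.abc import Callable


# Simplified stand-ins for the real onyx.db.models.ConnectorCredentialPair
# and onyx.access.models.DocExternalAccess; used only to illustrate the alias.
class ConnectorCredentialPair: ...


class DocExternalAccess: ...


# With the callback removed, a doc sync function takes only the cc_pair.
DocSyncFuncType = Callable[[ConnectorCredentialPair], list[DocExternalAccess]]


def example_doc_sync(cc_pair: ConnectorCredentialPair) -> list[DocExternalAccess]:
    # A real implementation would page through the connector's slim documents
    # and build external access entries; this stub just returns an empty list.
    return []


# Mapping a source name to its sync function type-checks only when the
# function's signature matches DocSyncFuncType.
DOC_SYNC_FUNCS: dict[str, DocSyncFuncType] = {"example": example_doc_sync}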

View File

@@ -111,7 +111,6 @@ async def login_as_anonymous_user(
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie("fastapiusersauth")
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,

View File

@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -42,10 +42,6 @@ class UserCreate(schemas.BaseUserCreate):
tenant_id: str | None = None
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons

View File

@@ -57,7 +57,7 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
@@ -216,6 +216,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
@@ -245,8 +246,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
@@ -265,16 +268,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdateWithRole(
user_update = UserUpdate(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
else:
@@ -282,6 +285,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:

View File

@@ -24,7 +24,6 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -198,8 +197,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
@@ -318,8 +316,6 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
HttpxPool.close_all()
if not celery_is_worker_primary(sender):
return

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -7,6 +8,7 @@ from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
@@ -130,25 +132,21 @@ class DynamicTenantScheduler(PersistentScheduler):
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
current_tenants = set()
for task_name, _ in current_schedule:
task_name = cast(str, task_name)
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
continue
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
if "_" in task_name:
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# -> "12345678-abcd-efgh-ijkl-12345678"
current_tenants.add(task_name.split("_")[-1])
logger.info(f"Found {len(current_tenants)} existing items in schedule")
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
for tenant_id in tenant_ids:
if tenant_id not in current_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)

View File

@@ -10,10 +10,6 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
@@ -58,23 +54,12 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received.")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -91,28 +91,6 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return False
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
"""
task_set: set[str] = set()
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
task_id = task_dict.get("headers", {}).get("id")
if task_id:
task_set.add(task_id)
return task_set
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.

View File

@@ -1,13 +1,10 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -20,7 +17,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
@@ -158,25 +154,3 @@ def celery_is_worker_primary(worker: Any) -> bool:
return True
return False
def httpx_init_vespa_pool(
max_keepalive_connections: int,
timeout: int = VESPA_REQUEST_TIMEOUT,
ssl_cert: str | None = None,
ssl_key: str | None = None,
) -> None:
httpx_cert = None
httpx_verify = False
if ssl_cert and ssl_key:
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
httpx_verify = True
HttpxPool.init_client(
name="vespa",
cert=httpx_cert,
verify=httpx_verify,
timeout=timeout,
http2=False,
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
)

View File

@@ -16,10 +16,6 @@ from shared_configs.configs import MULTI_TENANT
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
@@ -28,7 +24,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(hours=1),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
@@ -39,7 +35,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -51,7 +47,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -63,7 +59,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -75,7 +71,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -87,7 +83,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -99,7 +95,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -111,7 +107,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -123,7 +119,7 @@ cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -141,9 +137,7 @@ if LLM_MODEL_UPDATE_API_URL:
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
), # Check every hour
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -227,7 +221,7 @@ if not MULTI_TENANT:
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,

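As background on how entries shaped like the ones above are consumed: each dict becomes one periodic task in Celery beat, and the CLOUD_BEAT_SCHEDULE_MULTIPLIER factor simply stretches the timedelta interval (for example, seconds=15 * 8 gives a two-minute cadence). A minimal, self-contained sketch under assumed names follows; the broker URL and task name are placeholders, not Onyx's real configuration.

from datetime import timedelta

from celery import Celery

app = Celery("beat_sketch", broker="redis://localhost:6379/0")  # placeholder broker

# Dispatch-slowdown factor, mirroring the CLOUD_BEAT_SCHEDULE_MULTIPLIER idea above.
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8

app.conf.beat_schedule = {
    "cloud_check-for-indexing": {
        "task": "cloud_beat_task_generator",  # hypothetical task name for the sketch
        # 15 s * 8 = 120 s between dispatches when the multiplier is applied.
        "schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {"priority": 0, "expires": 15 * 60},
    },
}
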
View File

@@ -33,7 +33,6 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -140,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
submitted=datetime.now(timezone.utc),
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
redis_connector.delete.set_fence(fence_payload)
try:
@@ -178,13 +184,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
except TaskDependencyError:
redis_connector.delete.set_fence(None)
raise

View File

@@ -3,18 +3,14 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
@@ -25,10 +21,6 @@ from ee.onyx.external_permissions.sync_params import (
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -39,32 +31,21 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -76,9 +57,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
"""Jobs / utils for kicking off doc permissions sync tasks."""
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
@@ -113,17 +91,11 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -144,32 +116,14 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_permissions_sync_task(
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not payload_id:
if not tasks_created:
continue
task_logger.info(
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
except Exception:
task_logger.exception(
"Exception while validating permission sync fences"
)
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -188,15 +142,13 @@ def try_creating_permissions_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
LOCK_TIMEOUT = 30
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
@@ -221,25 +173,6 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
# set a basic fence to start
redis_connector.permissions.set_active()
payload = RedisConnectorPermissionSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.permissions.set_fence(payload)
result = app.send_task(
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
@@ -251,12 +184,12 @@ def try_creating_permissions_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
payload_id = payload.celery_task_id
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -264,7 +197,7 @@ def try_creating_permissions_sync_task(
if lock.owned():
lock.release()
return payload_id
return 1
@shared_task(
@@ -285,8 +218,6 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -374,17 +305,12 @@ def connector_permission_sync_generator_task(
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPermissionSyncPayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -434,8 +360,6 @@ def update_external_document_permissions_task(
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -465,330 +389,12 @@ def update_external_document_permissions_task(
document_ids=[doc_id],
)
elapsed = time.monotonic() - start
task_logger.info(
f"connector_id={connector_id} "
f"doc={doc_id} "
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} "
f"doc_id={doc_id}"
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
)
return False
return True
def validate_permission_sync_fences(
tenant_id: str | None,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
validate_permission_sync_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
reserved_generator_tasks,
r,
r_celery,
)
return
def validate_permission_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_permission_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.permissions.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"validate_permission_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.permissions.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.permissions.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_permission_sync_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
if tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.permissions.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
class PermissionSyncCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_connector.permissions.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"PermissionSyncCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"Permissions sync payload failed to validate. "
"Schema may have been updated."
)
return
if not payload:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"remaining={remaining} "
f"initial={initial}"
)
if remaining > 0:
return
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
task_logger.info(
f"Permissions sync finished: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"num_synced={initial}"
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.permissions.reset()
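
As background for the lock handling that runs through check_for_doc_permissions_sync and the monitoring code above: the beat tasks take a non-blocking Redis lock, periodically reacquire it to extend its TTL during long loops, and release it only if still owned. A stripped-down sketch of that pattern with redis-py, using an illustrative key name and a locally reachable Redis instance:

from redis import Redis
from redis.lock import Lock as RedisLock

r = Redis(host="localhost", port=6379)  # assumes a local Redis for the sketch

# Non-blocking "beat lock": if another worker already holds it, skip this run.
lock_beat: RedisLock = r.lock(
    "sketch:check_for_doc_permissions_sync_beat_lock", timeout=60
)

if lock_beat.acquire(blocking=False):
    try:
        for cc_pair_id in [1, 2, 3]:  # stand-in for the cc_pairs due for syncing
            # ... queue a permissions sync task for cc_pair_id here ...
            lock_beat.reacquire()  # reset the TTL so the lock survives long loops
    finally:
        if lock_beat.owned():
            lock_beat.release()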

View File

@@ -1,4 +1,3 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -10,7 +9,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -22,12 +20,9 @@ from ee.onyx.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -38,18 +33,12 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -102,17 +91,12 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@@ -147,7 +131,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
@@ -156,23 +139,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
# we want to run this less frequently than the overall task
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -215,12 +181,6 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
payload = RedisConnectorExternalGroupSyncPayload(
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
@@ -234,17 +194,13 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -271,7 +227,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -279,59 +235,19 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
if not redis_connector.external_group_sync.fenced: # The fence must exist
raise ValueError(
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_external_group_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
time.sleep(1)
continue
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
@@ -372,26 +288,11 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
sync_status=SyncStatus.FAILED,
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
@@ -400,135 +301,3 @@ def connector_external_group_sync_generator_task(
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
def validate_external_group_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.1. When the task is seen in the redis queue
1.2. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.external_group_sync.fenced:
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
if found:
# the celery task exists in the redis queue
# redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
# redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
# if redis_connector_index.active():
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
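A minimal sketch of the broker-level lookup that the queue check above relies on, assuming Celery's Redis broker (kombu keeps queued messages as JSON in a list named after the queue, with the task id in the message headers); the key layout is an assumption about broker internals, not an Onyx API:
import json

from redis import Redis


def task_in_queue(r_celery: Redis, queue_name: str, task_id: str) -> bool:
    """Return True if a message carrying task_id is still sitting in the broker queue."""
    for raw in r_celery.lrange(queue_name, 0, -1):
        try:
            message = json.loads(raw)
        except (TypeError, ValueError):
            continue
        if not isinstance(message, dict):
            continue
        if (message.get("headers") or {}).get("id") == task_id:
            return True
    return False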

View File

@@ -15,7 +15,6 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
@@ -23,9 +22,6 @@ from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_ta
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
@@ -41,14 +37,14 @@ from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -73,7 +69,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -124,7 +119,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list = get_active_search_settings_list(db_session)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
@@ -230,7 +227,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# or be currently executing
try:
validate_indexing_fences(
tenant_id, redis_client_replica, redis_client_celery, lock_beat
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
@@ -304,14 +301,6 @@ def connector_indexing_task(
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)

View File

@@ -291,20 +291,17 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str | None,
r_replica: Redis,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
for key_bytes in r_replica.scan_iter(
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()
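For context, a hedged sketch of what the reserved / unacked task lookup above amounts to at the broker level; the "unacked" hash name and its payload layout are kombu implementation details that may differ between versions:
import json

from redis import Redis


def get_unacked_task_ids(r_celery: Redis) -> set[str]:
    """Collect task ids that workers have prefetched but not yet acknowledged."""
    task_ids: set[str] = set()
    for raw in r_celery.hvals("unacked"):
        try:
            payload, _exchange, _routing_key = json.loads(raw)
        except (TypeError, ValueError):
            continue
        if not isinstance(payload, dict):
            continue
        task_id = (payload.get("headers") or {}).get("id")
        if task_id:
            task_ids.add(task_id)
    return task_ids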

View File

@@ -54,7 +54,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,

View File

@@ -4,7 +4,6 @@ from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
from typing import Literal
from celery import shared_task
from celery import Task
@@ -27,20 +26,17 @@ from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_active_search_settings
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -53,17 +49,6 @@ _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)
_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_START_LATENCY_KEY_FMT = (
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)
_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
@@ -126,7 +111,6 @@ class Metric(BaseModel):
}.items()
if v is not None
}
task_logger.info(f"Emitting metric: {data}")
optional_telemetry(
record_type=RecordType.METRIC,
data=data,
@@ -205,107 +189,48 @@ def _build_connector_start_latency_metric(
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
)
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
return Metric(
key=metric_key,
name="connector_start_latency",
value=start_latency,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
tags={},
)
def _build_connector_final_metrics(
def _build_run_success_metrics(
cc_pair: ConnectorCredentialPair,
recent_attempts: list[IndexAttempt],
redis_std: Redis,
) -> list[Metric]:
"""
Final metrics for connector index attempts:
- Boolean success/fail metric
- If success, emit:
* duration (seconds)
* doc_count
"""
metrics = []
for attempt in recent_attempts:
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping final metrics for connector {cc_pair.connector.id} "
f"index attempt {attempt.id}, already emitted."
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {attempt.id} because it has already been "
"emitted"
)
continue
# We only emit final metrics if the attempt is in a terminal state
if attempt.status not in [
if attempt.status in [
IndexingStatus.SUCCESS,
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
# Not finished; skip
continue
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
success = attempt.status == IndexingStatus.SUCCESS
metrics.append(
Metric(
key=metric_key, # We'll mark the same key for any final metrics
name="connector_run_succeeded",
value=success,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
"status": attempt.status.value,
},
task_logger.info(
f"Adding run success metric for index attempt {attempt.id} with status {attempt.status}"
)
)
if success:
# Make sure we have valid time_started
if attempt.time_started and attempt.time_updated:
duration_seconds = (
attempt.time_updated - attempt.time_started
).total_seconds()
metrics.append(
Metric(
key=None, # No need for a new key, or you can reuse the same if you prefer
name="connector_index_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
else:
task_logger.error(
f"Index attempt {attempt.id} succeeded but has missing time "
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
)
# For doc counts, choose whichever field is more relevant
doc_count = attempt.total_docs_indexed or 0
metrics.append(
Metric(
key=None,
name="connector_index_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
key=metric_key,
name="connector_run_succeeded",
value=attempt.status == IndexingStatus.SUCCESS,
tags={"source": str(cc_pair.connector.source)},
)
)
@@ -314,337 +239,189 @@ def _build_connector_final_metrics(
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""Collect metrics about connector runs from the past hour"""
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all connector credential pairs
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
# Might be more than one search setting, or just one
active_search_settings_list = get_active_search_settings_list(db_session)
active_search_settings = get_active_search_settings(db_session)
metrics = []
# If you want to process each cc_pair against each search setting:
for cc_pair in cc_pairs:
for search_settings in active_search_settings_list:
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.search_settings_id == search_settings.id,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
for cc_pair, search_settings in zip(cc_pairs, active_search_settings):
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.search_settings_id == search_settings.id,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
)
if not recent_attempts:
continue
if not recent_attempts:
continue
most_recent_attempt = recent_attempts[0]
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
most_recent_attempt = recent_attempts[0]
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
if one_hour_ago > most_recent_attempt.time_created:
continue
if one_hour_ago > most_recent_attempt.time_created:
continue
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Build a job_id for correlation
job_id = build_job_id(
"connector", str(cc_pair.id), str(most_recent_attempt.id)
)
# Add raw start time metric if available
if most_recent_attempt.time_started:
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=start_time_key,
name="connector_start_time",
value=most_recent_attempt.time_started.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Add raw end time metric if available and in terminal state
if (
most_recent_attempt.status.is_terminal()
and most_recent_attempt.time_updated
):
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=most_recent_attempt.id,
)
metrics.append(
Metric(
key=end_time_key,
name="connector_end_time",
value=most_recent_attempt.time_updated.timestamp(),
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
)
)
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Connector run success/failure
final_metrics = _build_connector_final_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(final_metrics)
# Connector run success/failure
run_success_metrics = _build_run_success_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(run_success_metrics)
return metrics
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""
Collect metrics for document set and group syncing:
- Success/failure status
- Start latency (for doc sets / user groups)
- Duration & doc count (only if success)
- Throughput (docs/min) (only if success)
- Raw start/end times for each sync
"""
"""Collect metrics about document set and group syncing speed"""
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records that ended in the last hour
# Get all sync records from the last hour
recent_sync_records = db_session.scalars(
select(SyncRecord)
.where(SyncRecord.sync_end_time.isnot(None))
.where(SyncRecord.sync_end_time >= one_hour_ago)
.order_by(SyncRecord.sync_end_time.desc())
.where(SyncRecord.sync_start_time >= one_hour_ago)
.order_by(SyncRecord.sync_start_time.desc())
).all()
task_logger.info(
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
)
metrics = []
for sync_record in recent_sync_records:
# Build a job_id for correlation
job_id = build_job_id("sync_record", str(sync_record.id))
# Skip if no end time (sync still in progress)
if not sync_record.sync_end_time:
continue
# Add raw start time metric
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
# Check if we already emitted a metric for this sync record
metric_key = (
f"sync_speed:{sync_record.sync_type}:"
f"{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Calculate sync duration in minutes
sync_duration_mins = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds() / 60.0
# Calculate sync speed (docs/min) - avoid division by zero
sync_speed = (
sync_record.num_docs_synced / sync_duration_mins
if sync_duration_mins > 0
else None
)
if sync_speed is None:
task_logger.error(
f"Something went wrong with sync speed calculation. "
f"Sync record: {sync_record.id}, duration: {sync_duration_mins}, "
f"docs synced: {sync_record.num_docs_synced}"
)
continue
task_logger.info(
f"Calculated sync speed for record {sync_record.id}: {sync_speed} docs/min"
)
metrics.append(
Metric(
key=start_time_key,
name="sync_start_time",
value=sync_record.sync_start_time.timestamp(),
key=metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# Add sync start latency metric
start_latency_key = (
f"sync_start_latency:{sync_record.sync_type}"
f":{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, start_latency_key):
task_logger.info(
f"Skipping start latency metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
else:
# Skip other sync types
task_logger.info(
f"Skipping sync record {sync_record.id} "
f"with type {sync_record.sync_type} "
f"and id {sync_record.entity_id} "
"because it is not a document set or user group"
)
continue
if entity is None:
task_logger.error(
f"Could not find entity for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
)
continue
# Calculate start latency in seconds
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
task_logger.info(
f"Calculated start latency for sync record {sync_record.id}: {start_latency} seconds"
)
if start_latency < 0:
task_logger.error(
f"Start latency is negative for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}. "
f"Sync start time: {sync_record.sync_start_time}, "
f"Entity last modified: {entity.time_last_modified_by_user}"
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Add raw end time metric if available
if sync_record.sync_end_time:
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
metrics.append(
Metric(
key=end_time_key,
name="sync_end_time",
value=sync_record.sync_end_time.timestamp(),
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
# Emit a SUCCESS/FAIL boolean metric
# Use a single Redis key to avoid re-emitting final metrics
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, final_metric_key):
# Evaluate success
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
metrics.append(
Metric(
key=final_metric_key,
name="sync_run_succeeded",
value=sync_succeeded,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# If successful, emit additional metrics
if sync_succeeded:
if sync_record.sync_end_time and sync_record.sync_start_time:
duration_seconds = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds()
else:
task_logger.error(
f"Invalid times for sync record {sync_record.id}: "
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
)
duration_seconds = None
doc_count = sync_record.num_docs_synced or 0
sync_speed = None
if duration_seconds and duration_seconds > 0:
duration_mins = duration_seconds / 60.0
sync_speed = (
doc_count / duration_mins if duration_mins > 0 else None
)
# Emit duration, doc count, speed
if duration_seconds is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_duration_seconds",
value=duration_seconds,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
metrics.append(
Metric(
key=final_metric_key,
name="sync_doc_count",
value=doc_count,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
if sync_speed is not None:
metrics.append(
Metric(
key=final_metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
else:
task_logger.error(
f"Invalid sync record {sync_record.id} with no duration"
)
# Emit start latency
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
sync_type=sync_record.sync_type,
entity_id=sync_record.entity_id,
sync_record_id=sync_record.id,
)
if not _has_metric_been_emitted(redis_std, start_latency_key):
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
if entity is None:
task_logger.error(
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
)
# Calculate start latency in seconds:
# (actual sync start) - (last modified time)
if (
entity is not None
and entity.time_last_modified_by_user
and sync_record.sync_start_time
):
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
if start_latency < 0:
task_logger.error(
f"Negative start latency for sync record {sync_record.id} "
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"job_id": job_id,
"sync_type": str(sync_record.sync_type),
},
)
)
return metrics
def build_job_id(
job_type: Literal["connector", "sync_record"],
primary_id: str,
secondary_id: str | None = None,
) -> str:
if job_type == "connector":
if secondary_id is None:
raise ValueError(
"secondary_id (attempt_id) is required for connector job_type"
)
return f"connector:{primary_id}:attempt:{secondary_id}"
elif job_type == "sync_record":
return f"sync_record:{primary_id}"
@shared_task(
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
ignore_result=True,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
@@ -658,9 +435,6 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
- Syncing speed metrics
- Worker status and task counts
"""
if tenant_id is not None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
@@ -685,20 +459,14 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
metric.log()
metric.emit(tenant_id)
if metric.key:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")

View File

@@ -25,30 +25,21 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
logger = setup_logger()
"""Jobs / utils for kicking off pruning tasks."""
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
@@ -87,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -213,14 +203,6 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
except Exception:
@@ -252,8 +234,6 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
@@ -367,52 +347,3 @@ def connector_pruning_generator_task(
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
"""Monitoring pruning utils, called in monitor_vespa_sync"""
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.PRUNING,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial,
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
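At its core, the pruning described above is a set difference between the document IDs the source currently returns and the IDs stored locally; a minimal sketch with generic names (not the connector API):
def compute_doc_ids_to_prune(source_doc_ids: set[str], indexed_doc_ids: set[str]) -> set[str]:
    # Anything indexed locally that the source no longer reports should be deleted.
    return indexed_doc_ids - source_doc_ids


# e.g. compute_doc_ids_to_prune({"a", "b"}, {"a", "b", "stale"}) -> {"stale"}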

View File

@@ -27,14 +27,12 @@ from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@@ -75,18 +73,14 @@ def document_by_cc_pair_cleanup_task(
"""
task_logger.debug(f"Task start: doc={document_id}")
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
active_search_settings.primary,
active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -156,13 +150,11 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
f"chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
@@ -221,7 +213,6 @@ def document_by_cc_pair_cleanup_task(
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
@@ -256,10 +247,6 @@ def cloud_beat_task_generator(
lock_beat.reacquire()
last_lock_time = current_time
# needed in the cloud
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
self.app.send_task(
task_name,
kwargs=dict(

View File

@@ -24,10 +24,6 @@ from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
monitor_ccpair_permissions_taskset,
)
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -38,6 +34,8 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
@@ -63,22 +61,23 @@ from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -98,7 +97,6 @@ logger = setup_logger()
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -652,6 +650,83 @@ def monitor_connector_deletion_taskset(
redis_connector.delete.reset()
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
initial = redis_connector.prune.generator_complete
if initial is None:
return
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -796,12 +871,7 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
ignore_result=True,
soft_time_limit=300,
bind=True,
)
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
@@ -825,17 +895,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# Replica usage notes
#
# False negatives are OK. (aka fail to see a key that exists on the master).
# We simply skip the monitoring work and it will be caught on the next pass.
#
# False positives are not OK, and are possible if we clear a fence on the master and
# then read from the replica. In this case, monitoring work could be done on a fence
# that no longer exists. To avoid this, we scan from the replica, but double check
# the result on the master.
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
@@ -895,19 +954,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
# scan and monitor activity to completion
phase_start = time.monotonic()
lock_beat.reacquire()
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
timings["connector"] = time.monotonic() - phase_start
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
timings["connector_deletion"] = time.monotonic() - phase_start
@@ -917,82 +974,70 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["documentset"] = time.monotonic() - phase_start
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["usergroup"] = time.monotonic() - phase_start
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["pruning"] = time.monotonic() - phase_start
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["indexing"] = time.monotonic() - phase_start
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r_replica.scan_iter(
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
if r.exists(key_bytes):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["permissions"] = time.monotonic() - phase_start
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
@@ -1023,11 +1068,9 @@ def vespa_metadata_sync_task(
try:
with get_session_with_tenant(tenant_id) as db_session:
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
retry_index = RetryDocumentIndex(doc_index)
@@ -1082,7 +1125,6 @@ def vespa_metadata_sync_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(

View File

@@ -35,7 +35,6 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
@@ -220,10 +219,9 @@ def _run_indexing(
callback=callback,
)
# Indexing is only done into one index at a time
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
primary_index_name=ctx.index_name, secondary_index_name=None
)
indexing_pipeline = build_indexing_pipeline(

View File

@@ -254,7 +254,6 @@ def _get_force_search_settings(
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
new_msg_req.query_override is not None,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
@@ -426,7 +425,9 @@ def stream_chat_message_objects(
)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
@@ -498,6 +499,14 @@ def stream_chat_message_objects(
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session

View File

@@ -200,8 +200,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
@@ -478,12 +476,6 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
@@ -617,8 +609,3 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes
MOCK_LLM_RESPONSE = (
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)

View File

@@ -300,8 +300,6 @@ class OnyxRedisLocks:
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
class OnyxCeleryPriority(int, Enum):

View File

@@ -1,7 +1,3 @@
import contextvars
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
@@ -24,9 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# These field types are considered metadata by default when
# treat_all_non_attachment_fields_as_metadata is False
DEFAULT_METADATA_FIELD_TYPES = {
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
@@ -64,31 +60,20 @@ class AirtableConnector(LoadConnector):
self,
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self._airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
self.airtable_client: AirtableApi | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._airtable_client = AirtableApi(credentials["airtable_access_token"])
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
@property
def airtable_client(self) -> AirtableApi:
if not self._airtable_client:
raise AirtableClientNotSetUpError()
return self._airtable_client
@staticmethod
def _extract_field_values(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
base_id: str,
@@ -127,33 +112,13 @@ class AirtableConnector(LoadConnector):
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
try:
attachment_response = requests.get(url)
attachment_response.raise_for_status()
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
return attachment_response.content
except requests.exceptions.HTTPError as e:
if e.response.status_code == 410:
logger.info(f"Refreshing attachment for {filename}")
# Re-fetch the record to get a fresh URL
refreshed_record = self.airtable_client.table(
base_id, table_id
).get(record_id)
for refreshed_attachment in refreshed_record["fields"][
field_name
]:
if refreshed_attachment.get("filename") == filename:
new_url = refreshed_attachment.get("url")
if new_url:
attachment_response = requests.get(new_url)
attachment_response.raise_for_status()
return attachment_response.content
return None
logger.error(f"Failed to refresh attachment for {filename}")
raise
attachment_content = get_attachment_with_retry(url, record_id)
attachment_content = get_attachment_with_retry(url)
if attachment_content:
try:
file_ext = get_file_ext(filename)
@@ -201,14 +166,8 @@ class AirtableConnector(LoadConnector):
return [(str(field_info), default_link)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata.
When treat_all_non_attachment_fields_as_metadata is True, all fields except
attachments are treated as metadata. Otherwise, only fields with types listed
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
if self.treat_all_non_attachment_fields_as_metadata:
return field_type.lower() != "multipleattachments"
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
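# Illustrative behaviour of the two branches above:
#   treat_all_non_attachment_fields_as_metadata=True  -> every type except "multipleAttachments" is metadata
#   treat_all_non_attachment_fields_as_metadata=False -> only types in DEFAULT_METADATA_FIELD_TYPES are metadata
#   _should_be_metadata("multipleAttachments") -> False in both modes (attachments become sections)
#   _should_be_metadata("createdBy")           -> True in both modes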
def _process_field(
self,
@@ -237,7 +196,6 @@ class AirtableConnector(LoadConnector):
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
@@ -275,7 +233,7 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document | None:
) -> Document:
"""Process a single Airtable record into a Document.
Args:
@@ -306,11 +264,6 @@ class AirtableConnector(LoadConnector):
field_val = fields.get(field_name)
field_type = field_schema.type
logger.debug(
f"Processing field '{field_name}' of type '{field_type}' "
f"for record '{record_id}'."
)
field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
@@ -324,10 +277,6 @@ class AirtableConnector(LoadConnector):
sections.extend(field_sections)
metadata.update(field_metadata)
if not sections:
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
@@ -364,47 +313,18 @@ class AirtableConnector(LoadConnector):
primary_field_name = field.name
break
logger.info(f"Starting to process Airtable records for {table.name}.")
record_documents: list[Document] = []
for record in records:
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
record_documents.append(document)
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record: dict[Future, RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
future_to_record[
executor.submit(
current_context.run,
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
] = record
# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
record = future_to_record[future]
try:
document = future.result()
if document:
record_documents.append(document)
except Exception as e:
logger.exception(f"Failed to process record {record['id']}")
raise e
yield record_documents
record_documents = []
# Yield any remaining records
if record_documents:
yield record_documents
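The context-propagation trick above (copying contextvars per submitted record so worker threads see the caller's tenant ID) generalizes to any ContextVar; a small self-contained sketch, with names invented for illustration:
import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed

request_tenant: contextvars.ContextVar[str] = contextvars.ContextVar("request_tenant", default="public")


def describe(item: int) -> str:
    # Runs in a worker thread but still sees the submitting thread's tenant value.
    return f"{request_tenant.get()}:{item}"


def process_all(items: list[int]) -> list[str]:
    results: list[str] = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        for item in items:
            # Copy per task: a Context object can only be entered by one thread at a time.
            ctx = contextvars.copy_context()
            futures.append(executor.submit(ctx.run, describe, item))
        for future in as_completed(futures):
            results.append(future.result())
    return results


request_tenant.set("tenant_a")
print(process_all([1, 2, 3]))  # each result is prefixed with "tenant_a"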

View File

@@ -232,29 +232,20 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
}
# Get labels
label_dicts = (
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
)
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
label_dicts = confluence_object["metadata"]["labels"]["results"]
page_labels = [label["name"] for label in label_dicts]
if page_labels:
doc_metadata["labels"] = page_labels
# Get last modified and author email
version_dict = confluence_object.get("version", {})
last_modified = (
datetime_from_string(version_dict.get("when"))
if version_dict.get("when")
else None
)
author_email = version_dict.get("by", {}).get("email")
title = confluence_object.get("title", "Untitled Document")
last_modified = datetime_from_string(confluence_object["version"]["when"])
author_email = confluence_object["version"].get("by", {}).get("email")
return Document(
id=object_url,
sections=[Section(link=object_url, text=object_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=title,
semantic_identifier=confluence_object["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author_email)] if author_email else None
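Note: one side of the hunk above reads the Confluence payload defensively with chained .get() calls and fallbacks, so a missing "metadata", "labels", or "version" key degrades gracefully instead of raising KeyError. A small sketch of that pattern on a hypothetical payload (datetime.fromisoformat stands in for the repo's datetime_from_string helper):

from datetime import datetime

def parse_confluence_object(obj: dict) -> dict:
    # Every lookup falls back to an empty dict/list so an absent key degrades gracefully.
    label_dicts = obj.get("metadata", {}).get("labels", {}).get("results", [])
    labels = [label.get("name") for label in label_dicts if label.get("name")]
    version = obj.get("version", {})
    last_modified = (
        datetime.fromisoformat(version["when"]) if version.get("when") else None
    )
    return {
        "title": obj.get("title", "Untitled Document"),
        "labels": labels,
        "last_modified": last_modified,
        "author_email": version.get("by", {}).get("email"),
    }

print(parse_confluence_object({"title": "Page", "version": {"when": "2025-01-25T18:16:21"}}))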

View File

@@ -1,5 +1,4 @@
import sys
import time
from datetime import datetime
from onyx.connectors.interfaces import BaseConnector
@@ -46,17 +45,7 @@ class ConnectorRunner:
def run(self) -> GenerateDocumentsOutput:
"""Adds additional exception logging to the connector."""
try:
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long the connector is taking
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
yield batch
start = time.monotonic()
yield from self.doc_batch_generator
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
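Note: one side of the run() hunk above wraps the batch generator so the time to build each batch is logged, instead of yielding from it directly. A generic sketch of that timing wrapper:

import logging
import time
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)

def timed_batches(batches: Iterable[T]) -> Iterator[T]:
    # Measure the time between requesting a batch and receiving it from the underlying generator.
    start = time.monotonic()
    for batch in batches:
        logger.debug("Batch took %.2f seconds to build.", time.monotonic() - start)
        yield batch
        start = time.monotonic()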

View File

@@ -50,9 +50,6 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
current_link = ""
current_text = ""
if transcript["sentences"] is None:
return None
for sentence in transcript["sentences"]:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:

View File

@@ -150,16 +150,6 @@ class Document(DocumentBase):
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"

View File

@@ -1,14 +1,16 @@
import io
import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import unquote
from typing import Optional
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from pydantic import BaseModel
from office365.onedrive.sites.site import Site # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -27,25 +29,16 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
Args:
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
If None, all folders will be accessed.
"""
url: str
drive_name: str | None
folder_path: str | None
@dataclass
class SiteData:
url: str | None
folder: Optional[str]
sites: list = field(default_factory=list)
driveitems: list = field(default_factory=list)
def _convert_driveitem_to_document(
driveitem: DriveItem,
drive_name: str,
) -> Document:
file_text = extract_file_text(
file=io.BytesIO(driveitem.get_content().execute_query().value),
@@ -65,7 +58,7 @@ def _convert_driveitem_to_document(
email=driveitem.last_modified_by.user.email,
)
],
metadata={"drive": drive_name},
metadata={},
)
return doc
@@ -77,179 +70,93 @@ class SharepointConnector(LoadConnector, PollConnector):
sites: list[str] = [],
) -> None:
self.batch_size = batch_size
self._graph_client: GraphClient | None = None
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
self.msal_app: msal.ConfidentialClientApplication | None = None
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
return self._graph_client
self.graph_client: GraphClient | None = None
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
@staticmethod
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
if "sites" in parts:
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
remaining_parts = parts[sites_index + 2 :]
# Extract drive name and folder path
if remaining_parts:
drive_name = unquote(remaining_parts[0])
folder_path = (
"/".join(unquote(part) for part in remaining_parts[1:])
if len(remaining_parts) > 1
else None
)
else:
drive_name = None
folder_path = None
folder = (
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
)
site_data_list.append(
SiteDescriptor(
url=site_url,
drive_name=drive_name,
folder_path=folder_path,
)
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
)
return site_data_list
def _fetch_driveitems(
def _populate_sitedata_driveitems(
self,
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> list[tuple[DriveItem, str]]:
final_driveitems: list[tuple[DriveItem, str]] = []
try:
site = self.graph_client.sites.get_by_url(site_descriptor.url)
) -> None:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
# Get all drives in the site
drives = site.drives.get().execute_query()
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
for element in self.site_data:
sites: list[Site] = []
for site in element.sites:
site_sublist = site.lists.get().execute_query()
sites.extend(site_sublist)
# Filter drives based on the requested drive name
if site_descriptor.drive_name:
drives = [
drive
for drive in drives
if drive.name == site_descriptor.drive_name
or (
drive.name == "Documents"
and site_descriptor.drive_name == "Shared Documents"
)
]
if not drives:
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
return []
# Process each matching drive
for drive in drives:
for site in sites:
try:
root_folder = drive.root
if site_descriptor.folder_path:
# If a specific folder is requested, navigate to it
for folder_part in site_descriptor.folder_path.split("/"):
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
query = root_folder.get_files(
recursive=True,
page_size=1000,
)
query = site.drive.root.get_files(True, 1000)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
logger.debug(
f"Found {len(driveitems)} items in drive '{drive.name}'"
)
# Use "Shared Documents" as the library name for the default "Documents" drive
drive_name = (
"Shared Documents" if drive.name == "Documents" else drive.name
)
# Filter items based on folder path if specified
if site_descriptor.folder_path:
# Filter items to ensure they're in the specified folder or its subfolders
# The path will be in format: /drives/{drive_id}/root:/folder/path
driveitems = [
if element.folder:
filtered_driveitems = [
item
for item in driveitems
if any(
path_part == site_descriptor.folder_path
or path_part.startswith(
site_descriptor.folder_path + "/"
)
for path_part in item.parent_reference.path.split(
"root:/"
)[1].split("/")
)
if element.folder in item.parent_reference.path
]
if len(driveitems) == 0:
all_paths = [
item.parent_reference.path for item in driveitems
]
logger.warning(
f"Nothing found for folder '{site_descriptor.folder_path}' "
f"in; any of valid paths: {all_paths}"
)
element.driveitems.extend(filtered_driveitems)
else:
element.driveitems.extend(driveitems)
# Filter items based on time window if specified
if start is not None and end is not None:
driveitems = [
item
for item in driveitems
if start
<= item.last_modified_datetime.replace(tzinfo=timezone.utc)
<= end
]
logger.debug(
f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
)
except Exception:
# Sites include things that do not contain .drive.root so this fails
# but this is fine, as there are no actual documents in those
pass
for item in driveitems:
final_driveitems.append((item, drive_name))
def _populate_sitedata_sites(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")
except Exception as e:
# Sites include things that do not contain drives so this fails
# but this is fine, as there are no actual documents in those
logger.warning(f"Failed to process site: {str(e)}")
return final_driveitems
def _fetch_sites(self) -> list[SiteDescriptor]:
sites = self.graph_client.sites.get_all().execute_query()
site_descriptors = [
SiteDescriptor(
url=sites.resource_url,
drive_name=None,
folder_path=None,
)
]
return site_descriptors
if self.site_data:
for element in self.site_data:
element.sites = [
self.graph_client.sites.get_by_url(element.url)
.get()
.execute_query()
]
else:
sites = self.graph_client.sites.get_all().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
site_descriptors = self.site_descriptors or self._fetch_sites()
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
self._populate_sitedata_sites()
self._populate_sitedata_driveitems(start=start, end=end)
# goes over all urls, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for site_descriptor in site_descriptors:
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
for element in self.site_data:
for driveitem in element.driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
doc_batch.append(_convert_driveitem_to_document(driveitem))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -261,26 +168,22 @@ class SharepointConnector(LoadConnector, PollConnector):
sp_client_secret = credentials["sp_client_secret"]
sp_directory_id = credentials["sp_directory_id"]
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
client_credential=sp_client_secret,
)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
client_credential=sp_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
self._graph_client = GraphClient(_acquire_token_func)
self.graph_client = GraphClient(_acquire_token_func)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -289,19 +192,19 @@ class SharepointConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
if __name__ == "__main__":
connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
connector.load_credentials(
{
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
"sp_client_id": os.environ["SP_CLIENT_ID"],
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()
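Note: both sides of the load_credentials hunk above acquire an app-only Microsoft Graph token through MSAL's ConfidentialClientApplication and hand the callback to GraphClient; the newer side merely keeps the MSAL app on the connector instance. A standalone sketch of that token flow (the environment variable names are illustrative and real credentials are required for an actual call):

import os
from typing import Any

import msal  # type: ignore
from office365.graph_client import GraphClient  # type: ignore

def build_graph_client(client_id: str, client_secret: str, directory_id: str) -> GraphClient:
    authority_url = f"https://login.microsoftonline.com/{directory_id}"
    app = msal.ConfidentialClientApplication(
        authority=authority_url,
        client_id=client_id,
        client_credential=client_secret,
    )

    def _acquire_token() -> dict[str, Any]:
        # Client-credentials grant against the Graph default scope.
        return app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])

    return GraphClient(_acquire_token)

if __name__ == "__main__":
    client = build_graph_client(
        os.environ["SP_CLIENT_ID"],
        os.environ["SP_CLIENT_SECRET"],
        os.environ["SP_CLIENT_DIRECTORY_ID"],
    )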

View File

@@ -104,11 +104,8 @@ def make_slack_api_rate_limited(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction", "internal_error"]:
# Log internal_error and return the response instead of failing
logger.warning(
f"Slack call encountered '{error}', skipping and continuing..."
)
elif error in ["already_reacted", "no_reaction"]:
# The response isn't used for reactions, this is basically just a pass
return e.response
else:
# Raise the error for non-transient errors
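Note: the hunk above widens the set of Slack errors that are logged and skipped rather than raised ("already_reacted", "no_reaction", and on one side "internal_error"). A generic sketch of that retry-or-skip wrapper around a Slack call, assuming slack_sdk's SlackApiError; the helper name and retry limit are illustrative:

import logging
import time
from collections.abc import Callable
from typing import Any

from slack_sdk.errors import SlackApiError

logger = logging.getLogger(__name__)
SKIPPABLE_ERRORS = {"already_reacted", "no_reaction", "internal_error"}

def call_rate_limited(call: Callable[[], Any], max_attempts: int = 5) -> Any:
    for _ in range(max_attempts):
        try:
            return call()
        except SlackApiError as e:
            error = e.response.get("error", "")
            if e.response.status_code == 429:
                retry_after = int(e.response.headers.get("Retry-After", 1))
                logger.warning("Rate limited, retrying after %s seconds", retry_after)
                time.sleep(retry_after)
            elif error in SKIPPABLE_ERRORS:
                # The response is not used for reactions, so skipping is safe here.
                logger.warning("Skipping transient Slack error '%s'", error)
                return e.response
            else:
                raise
    raise RuntimeError("Exceeded Slack retry attempts")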

View File

@@ -180,28 +180,23 @@ class TeamsConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_team_list: list[str] = teams
self.msal_app: msal.ConfidentialClientApplication | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
teams_client_id = credentials["teams_client_id"]
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token

View File

@@ -67,7 +67,10 @@ class SearchPipeline:
self.rerank_metrics_callback = rerank_metrics_callback
self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(self.search_settings, None)
self.document_index = get_default_document_index(
primary_index_name=self.search_settings.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config
# Preprocessing steps generate this

View File

@@ -28,9 +28,6 @@ class SyncType(str, PyEnum):
DOCUMENT_SET = "document_set"
USER_GROUP = "user_group"
CONNECTOR_DELETION = "connector_deletion"
PRUNING = "pruning" # not really a sync, but close enough
EXTERNAL_PERMISSIONS = "external_permissions"
EXTERNAL_GROUP = "external_group"
def __str__(self) -> str:
return self.value

View File

@@ -3,8 +3,6 @@ from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet
from onyx.db.models import LLMProvider as LLMProviderModel
@@ -126,29 +124,10 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
def fetch_existing_llm_providers(
db_session: Session,
) -> list[LLMProviderModel]:
stmt = select(LLMProviderModel)
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
if AUTH_TYPE != AuthType.DISABLED:
# User is anonymous
return list(
db_session.scalars(
select(LLMProviderModel).where(
LLMProviderModel.is_public == True # noqa: E712
)
).all()
)
else:
# If auth is disabled, user has access to all providers
return fetch_existing_llm_providers(db_session)
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_select = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id

View File

@@ -150,7 +150,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
@@ -162,7 +161,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
pinned_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -746,34 +747,6 @@ class SearchSettings(Base):
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
@property
def large_chunks_enabled(self) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return SearchSettings.can_use_large_chunks(
self.multipass_indexing, self.model_name, self.provider_type
)
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return (
multipass
and model_name.startswith("nomic-ai")
and provider_type != EmbeddingProvider.COHERE
)
class IndexAttempt(Base):
"""
@@ -1116,10 +1089,6 @@ class ChatSession(Base):
llm_override: Mapped[LLMOverride | None] = mapped_column(
PydanticType(LLMOverride), nullable=True
)
# The latest temperature override specified by the user
temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)

View File

@@ -11,7 +11,7 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -291,9 +291,8 @@ def get_personas_for_user(
include_deleted: bool = False,
joinedload_all: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = select(Persona).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
@@ -303,16 +302,14 @@ def get_personas_for_user(
if joinedload_all:
stmt = stmt.options(
selectinload(Persona.prompts),
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
joinedload(Persona.prompts),
joinedload(Persona.tools),
joinedload(Persona.document_sets),
joinedload(Persona.groups),
joinedload(Persona.users),
)
results = db_session.execute(stmt).scalars().all()
return results
return db_session.execute(stmt).unique().scalars().all()
def get_personas(db_session: Session) -> Sequence[Persona]:
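Note: the hunk above swaps joinedload plus .unique() for selectinload when eagerly loading Persona collections; selectinload issues separate SELECT ... IN queries, so parent rows are not duplicated and the .unique() call becomes unnecessary. A minimal SQLAlchemy 2.0 sketch of the pattern on hypothetical models:

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)

class Base(DeclarativeBase):
    pass

class Persona(Base):
    __tablename__ = "persona"
    id: Mapped[int] = mapped_column(primary_key=True)
    prompts: Mapped[list["Prompt"]] = relationship(back_populates="persona")

class Prompt(Base):
    __tablename__ = "prompt"
    id: Mapped[int] = mapped_column(primary_key=True)
    persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
    persona: Mapped[Persona] = relationship(back_populates="prompts")

def load_personas(db_session: Session) -> list[Persona]:
    # selectinload emits a second SELECT ... WHERE persona_id IN (...) instead of a JOIN,
    # so parent rows come back deduplicated and no .unique() call is needed.
    stmt = select(Persona).options(selectinload(Persona.prompts))
    return list(db_session.scalars(stmt).all())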

View File

@@ -29,21 +29,9 @@ from onyx.utils.logger import setup_logger
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
class ActiveSearchSettings:
primary: SearchSettings
secondary: SearchSettings | None
def __init__(
self, primary: SearchSettings, secondary: SearchSettings | None
) -> None:
self.primary = primary
self.secondary = secondary
def create_search_settings(
search_settings: SavedSearchSettings,
db_session: Session,
@@ -155,27 +143,21 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> ActiveSearchSettings:
"""Returns active search settings. Secondary search settings may be None."""
# Get the primary and secondary search settings
primary_search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
return ActiveSearchSettings(
primary=primary_search_settings, secondary=secondary_search_settings
)
def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings as a list. Primary settings are the first element,
and if secondary search settings exist, they will be the second element."""
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
active_search_settings = get_active_search_settings(db_session)
search_settings_list.append(active_search_settings.primary)
if active_search_settings.secondary:
search_settings_list.append(active_search_settings.secondary)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list

View File

@@ -8,64 +8,20 @@ from sqlalchemy.orm import Session
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
from onyx.setup import setup_logger
logger = setup_logger()
def insert_sync_record(
db_session: Session,
entity_id: int,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Insert a new sync record into the database, cancelling any existing in-progress records.
"""Insert a new sync record into the database.
Args:
db_session: The database session to use
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
sync_type: The type of sync operation
"""
# If an existing in-progress sync record exists, mark as cancelled
existing_in_progress_sync_record = fetch_latest_sync_record(
db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
)
if existing_in_progress_sync_record is not None:
logger.info(
f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
f"for entity_id={entity_id} sync_type={sync_type}"
)
mark_sync_records_as_cancelled(db_session, entity_id, sync_type)
return _create_sync_record(db_session, entity_id, sync_type)
def mark_sync_records_as_cancelled(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> None:
stmt = (
update(SyncRecord)
.where(
and_(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
)
)
.values(sync_status=SyncStatus.CANCELED)
)
db_session.execute(stmt)
db_session.commit()
def _create_sync_record(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Create and insert a new sync record into the database."""
sync_record = SyncRecord(
entity_id=entity_id,
sync_type=sync_type,
@@ -83,7 +39,6 @@ def fetch_latest_sync_record(
db_session: Session,
entity_id: int,
sync_type: SyncType,
sync_status: SyncStatus | None = None,
) -> SyncRecord | None:
"""Fetch the most recent sync record for a given entity ID and status.
@@ -104,9 +59,6 @@ def fetch_latest_sync_record(
.limit(1)
)
if sync_status is not None:
stmt = stmt.where(SyncRecord.sync_status == sync_status)
result = db_session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -4,63 +4,24 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import MultipassConfig
from shared_configs.configs import MULTI_TENANT
DEFAULT_BATCH_SIZE = 30
DEFAULT_INDEX_NAME = "danswer_chunk"
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
"""
Determines whether multipass should be used based on the search settings
or the default config if settings are unavailable.
"""
if search_settings is not None:
return search_settings.multipass_indexing
return ENABLE_MULTIPASS_INDEXING
def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
"""
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type
)
return MultipassConfig(
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
)
def get_both_index_properties(
db_session: Session,
) -> tuple[str, str | None, bool, bool | None]:
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
search_settings = get_current_search_settings(db_session)
config_1 = get_multipass_config(search_settings)
search_settings_new = get_secondary_search_settings(db_session)
if not search_settings_new:
return search_settings.index_name, None, config_1.enable_large_chunks, None
return search_settings.index_name, None
config_2 = get_multipass_config(search_settings)
return (
search_settings.index_name,
search_settings_new.index_name,
config_1.enable_large_chunks,
config_2.enable_large_chunks,
)
return search_settings.index_name, search_settings_new.index_name
def translate_boost_count_to_multiplier(boost: int) -> float:
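Note: the hunk above moves the multipass/large-chunk decision around, but conceptually it is a small pure predicate over the embedder configuration: large chunks only make sense with multipass indexing, a Nomic local model, and a non-Cohere provider. A sketch of that decision with the settings reduced to plain values (the enum and field names here are illustrative):

from dataclasses import dataclass
from enum import Enum

class EmbeddingProvider(str, Enum):
    COHERE = "cohere"
    OPENAI = "openai"

@dataclass
class MultipassConfig:
    multipass_indexing: bool
    enable_large_chunks: bool

def get_multipass_config(
    multipass_indexing: bool,
    model_name: str,
    provider_type: EmbeddingProvider | None,
) -> MultipassConfig:
    # Only local Nomic models support the larger context needed for large chunks;
    # Cohere recommends staying around ~512 tokens, so it is excluded.
    enable_large_chunks = (
        multipass_indexing
        and model_name.startswith("nomic-ai")
        and provider_type != EmbeddingProvider.COHERE
    )
    return MultipassConfig(
        multipass_indexing=multipass_indexing, enable_large_chunks=enable_large_chunks
    )

print(get_multipass_config(True, "nomic-ai/nomic-embed-text-v1", None))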

View File

@@ -1,7 +1,5 @@
import httpx
from sqlalchemy.orm import Session
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
@@ -9,28 +7,17 @@ from shared_configs.configs import MULTI_TENANT
def get_default_document_index(
search_settings: SearchSettings,
secondary_search_settings: SearchSettings | None,
httpx_client: httpx.Client | None = None,
primary_index_name: str,
secondary_index_name: str | None,
) -> DocumentIndex:
"""Primary index is the index that is used for querying/updating etc.
Secondary index is for when the currently used index and the upcoming
index both need to be updated, updates are applied to both indices"""
secondary_index_name: str | None = None
secondary_large_chunks_enabled: bool | None = None
if secondary_search_settings:
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
@@ -40,6 +27,6 @@ def get_current_primary_default_document_index(db_session: Session) -> DocumentI
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
search_settings,
None,
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)

View File

@@ -231,22 +231,21 @@ def _get_chunks_via_visit_api(
return document_chunks
# TODO(rkuo): candidate for removal if not being used
# @retry(tries=10, delay=1, backoff=2)
# def get_all_vespa_ids_for_document_id(
# document_id: str,
# index_name: str,
# filters: IndexFilters | None = None,
# get_large_chunks: bool = False,
# ) -> list[str]:
# document_chunks = _get_chunks_via_visit_api(
# chunk_request=VespaChunkRequest(document_id=document_id),
# index_name=index_name,
# filters=filters or IndexFilters(access_control_list=None),
# field_names=[DOCUMENT_ID],
# get_large_chunks=get_large_chunks,
# )
# return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
@retry(tries=10, delay=1, backoff=2)
def get_all_vespa_ids_for_document_id(
document_id: str,
index_name: str,
filters: IndexFilters | None = None,
get_large_chunks: bool = False,
) -> list[str]:
document_chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
field_names=[DOCUMENT_ID],
get_large_chunks=get_large_chunks,
)
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
def parallel_visit_api_retrieval(

View File

@@ -25,6 +25,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.engine import get_session_with_tenant
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
@@ -40,12 +41,12 @@ from onyx.document_index.vespa.chunk_retrieval import (
)
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import (
get_multipass_config,
)
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
@@ -131,34 +132,12 @@ class VespaIndex(DocumentIndex):
self,
index_name: str,
secondary_index_name: str | None,
large_chunks_enabled: bool,
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.large_chunks_enabled = large_chunks_enabled
self.secondary_large_chunks_enabled = secondary_large_chunks_enabled
self.multitenant = multitenant
self.httpx_client_context: BaseHTTPXClientContext
if httpx_client:
self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
else:
self.httpx_client_context = TemporaryHTTPXClientContext(
get_vespa_http_client
)
self.index_to_large_chunks_enabled: dict[str, bool] = {}
self.index_to_large_chunks_enabled[index_name] = large_chunks_enabled
if secondary_index_name and secondary_large_chunks_enabled:
self.index_to_large_chunks_enabled[
secondary_index_name
] = secondary_large_chunks_enabled
self.http_client = get_vespa_http_client()
def ensure_indices_exist(
self,
@@ -352,7 +331,7 @@ class VespaIndex(DocumentIndex):
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
self.httpx_client_context as http_client,
get_vespa_http_client() as http_client,
):
# We require the start and end index for each document in order to
# know precisely which chunks to delete. This information exists for
@@ -411,11 +390,9 @@ class VespaIndex(DocumentIndex):
for doc_id in all_doc_ids
}
@classmethod
@staticmethod
def _apply_updates_batched(
cls,
updates: list[_VespaUpdateRequest],
httpx_client: httpx.Client,
batch_size: int = BATCH_SIZE,
) -> None:
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
@@ -437,7 +414,7 @@ class VespaIndex(DocumentIndex):
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx_client as http_client,
get_vespa_http_client() as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
@@ -478,7 +455,7 @@ class VespaIndex(DocumentIndex):
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
with self.httpx_client_context as http_client:
with get_vespa_http_client() as http_client:
for update_request in update_requests:
for doc_info in update_request.minimal_document_indexing_info:
for index_name in index_names:
@@ -534,8 +511,7 @@ class VespaIndex(DocumentIndex):
)
)
with self.httpx_client_context as httpx_client:
self._apply_updates_batched(processed_updates_requests, httpx_client)
self._apply_updates_batched(processed_updates_requests)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
@@ -547,7 +523,6 @@ class VespaIndex(DocumentIndex):
index_name: str,
fields: VespaDocumentFields,
doc_id: str,
http_client: httpx.Client,
) -> None:
"""
Update a single "chunk" (document) in Vespa using its chunk ID.
@@ -579,17 +554,18 @@ class VespaIndex(DocumentIndex):
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
try:
resp = http_client.put(
vespa_url,
headers={"Content-Type": "application/json"},
json=update_dict,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
raise
with get_vespa_http_client(http2=False) as http_client:
try:
resp = http_client.put(
vespa_url,
headers={"Content-Type": "application/json"},
json=update_dict,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
raise
def update_single(
self,
@@ -603,16 +579,24 @@ class VespaIndex(DocumentIndex):
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
doc_chunk_count = 0
with self.httpx_client_context as httpx_client:
for (
index_name,
large_chunks_enabled,
) in self.index_to_large_chunks_enabled.items():
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
multipass_config = get_multipass_config(
db_session=db_session,
primary_index=index_name == self.index_name,
)
large_chunks_enabled = multipass_config.enable_large_chunks
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=httpx_client,
http_client=http_client,
document_id=doc_id,
previous_chunk_count=chunk_count,
new_chunk_count=0,
@@ -628,7 +612,10 @@ class VespaIndex(DocumentIndex):
for doc_chunk_id in doc_chunk_ids:
self.update_single_chunk(
doc_chunk_id, index_name, fields, doc_id, httpx_client
doc_chunk_id=doc_chunk_id,
index_name=index_name,
fields=fields,
doc_id=doc_id,
)
return doc_chunk_count
@@ -650,13 +637,19 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with self.httpx_client_context as http_client, concurrent.futures.ThreadPoolExecutor(
with get_vespa_http_client(
http2=False
) as http_client, concurrent.futures.ThreadPoolExecutor(
max_workers=NUM_THREADS
) as executor:
for (
index_name,
large_chunks_enabled,
) in self.index_to_large_chunks_enabled.items():
for index_name in index_names:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
multipass_config = get_multipass_config(
db_session=db_session,
primary_index=index_name == self.index_name,
)
large_chunks_enabled = multipass_config.enable_large_chunks
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=http_client,
@@ -825,9 +818,6 @@ class VespaIndex(DocumentIndex):
"""
Deletes all entries in the specified index with the given tenant_id.
Currently unused, but we anticipate this being useful. The entire flow does not
use the httpx connection pool of an instance.
Parameters:
tenant_id (str): The tenant ID whose documents are to be deleted.
index_name (str): The name of the index from which to delete documents.
@@ -860,8 +850,6 @@ class VespaIndex(DocumentIndex):
"""
Retrieves all document IDs with the specified tenant_id, handling pagination.
Internal helper function for delete_entries_by_tenant_id.
Parameters:
tenant_id (str): The tenant ID to search for.
index_name (str): The name of the index to search in.
@@ -894,8 +882,8 @@ class VespaIndex(DocumentIndex):
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
)
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=query_params, timeout=None)
with get_vespa_http_client(no_timeout=True) as http_client:
response = http_client.get(url, params=query_params)
response.raise_for_status()
search_result = response.json()
@@ -925,11 +913,6 @@ class VespaIndex(DocumentIndex):
"""
Deletes documents in batches using multiple threads.
Internal helper function for delete_entries_by_tenant_id.
This is a class method and does not use the httpx pool of the instance.
This is OK because we don't use this method often.
Parameters:
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
batch_size (int): The number of documents to delete in each batch.
@@ -942,14 +925,13 @@ class VespaIndex(DocumentIndex):
response = http_client.delete(
delete_request.url,
headers={"Content-Type": "application/json"},
timeout=None,
)
response.raise_for_status()
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
with get_vespa_http_client() as http_client:
with get_vespa_http_client(no_timeout=True) as http_client:
for batch_start in range(0, len(delete_requests), batch_size):
batch = delete_requests[batch_start : batch_start + batch_size]
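Note: several hunks above differ only in where the httpx client comes from (a shared client held by the index instance versus a fresh get_vespa_http_client()); the core operation is a JSON partial-update PUT against the Vespa document endpoint with create=true, surfacing HTTPStatusError after logging. A stripped-down sketch of that call (the endpoint URL and field payload are illustrative, and the usage at the bottom needs a running Vespa instance):

import logging

import httpx

logger = logging.getLogger(__name__)

def update_doc_chunk(
    http_client: httpx.Client, endpoint: str, doc_chunk_id: str, update_dict: dict
) -> None:
    # ?create=true makes the partial update an upsert if the chunk does not exist yet.
    vespa_url = f"{endpoint}/{doc_chunk_id}?create=true"
    try:
        resp = http_client.put(
            vespa_url,
            headers={"Content-Type": "application/json"},
            json=update_dict,
        )
        resp.raise_for_status()
    except httpx.HTTPStatusError as e:
        logger.error("Failed to update chunk %s: %s", doc_chunk_id, e.response.text)
        raise

if __name__ == "__main__":
    # Hypothetical endpoint; the real URL comes from DOCUMENT_ID_ENDPOINT.format(index_name=...).
    endpoint = "http://localhost:8080/document/v1/default/danswer_chunk/docid"
    with httpx.Client(http2=False) as client:
        update_doc_chunk(client, endpoint, "chunk-0", {"fields": {"boost": {"assign": 1.0}}})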

View File

@@ -1,19 +1,21 @@
import concurrent.futures
import json
import uuid
from abc import ABC
from abc import abstractmethod
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
import httpx
from retry import retry
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.document_index_utils import get_uuid_from_chunk
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
@@ -48,9 +50,10 @@ from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import EmbeddingProvider
from onyx.indexing.models import MultipassConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -272,42 +275,46 @@ def check_for_final_chunk_existence(
index += 1
class BaseHTTPXClientContext(ABC):
"""Abstract base class for an HTTPX client context manager."""
@abstractmethod
def __enter__(self) -> httpx.Client:
pass
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
pass
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
"""
Determines whether multipass should be used based on the search settings
or the default config if settings are unavailable.
"""
if search_settings is not None:
return search_settings.multipass_indexing
return ENABLE_MULTIPASS_INDEXING
class GlobalHTTPXClientContext(BaseHTTPXClientContext):
"""Context manager for a global HTTPX client that does not close it."""
def __init__(self, client: httpx.Client):
self._client = client
def __enter__(self) -> httpx.Client:
return self._client # Reuse the global client
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
pass # Do nothing; don't close the global client
def can_use_large_chunks(multipass: bool, search_settings: SearchSettings) -> bool:
"""
Given multipass usage and an embedder, decides whether large chunks are allowed
based on model/provider constraints.
"""
# Only local models that support a larger context are from Nomic
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
return (
multipass
and search_settings.model_name.startswith("nomic-ai")
and search_settings.provider_type != EmbeddingProvider.COHERE
)
class TemporaryHTTPXClientContext(BaseHTTPXClientContext):
"""Context manager for a temporary HTTPX client that closes it after use."""
def __init__(self, client_factory: Callable[[], httpx.Client]):
self._client_factory = client_factory
self._client: httpx.Client | None = None # Client will be created in __enter__
def __enter__(self) -> httpx.Client:
self._client = self._client_factory() # Create a new client
return self._client
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
if self._client:
self._client.close()
def get_multipass_config(
db_session: Session, primary_index: bool = True
) -> MultipassConfig:
"""
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
search_settings = (
get_current_search_settings(db_session)
if primary_index
else get_secondary_search_settings(db_session)
)
multipass = should_use_multipass(search_settings)
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
enable_large_chunks = can_use_large_chunks(multipass, search_settings)
return MultipassConfig(
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
)

View File

@@ -55,7 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
"""Vespa does not take in unicode chars that aren't valid for XML.
This removes them."""
_illegal_xml_chars_RE: re.Pattern = re.compile(
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFDD0-\uFDEF\uFFFE\uFFFF]"
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)
return _illegal_xml_chars_RE.sub("", text)
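Note: the regex hunk above strips characters that are not valid in XML before sending text to Vespa; the only difference between the two sides is whether the \uFDD0-\uFDEF noncharacter range is included. A quick usage check of the broader pattern:

import re

_illegal_xml_chars_RE: re.Pattern = re.compile(
    "[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFDD0-\uFDEF\uFFFE\uFFFF]"
)

def remove_invalid_unicode_chars(text: str) -> str:
    return _illegal_xml_chars_RE.sub("", text)

print(remove_invalid_unicode_chars("ok\x00text\uFFFE"))  # -> "oktext"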

View File

@@ -358,13 +358,7 @@ def extract_file_text(
try:
if get_unstructured_api_key():
try:
return unstructured_to_text(file, file_name)
except Exception as unstructured_error:
logger.error(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
# Fall through to normal processing
return unstructured_to_text(file, file_name)
if file_name or extension:
if extension is not None:

View File

@@ -52,7 +52,7 @@ def _sdk_partition_request(
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="fast")
req = _sdk_partition_request(file, file_name, strategy="auto")
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())

View File

@@ -1,57 +0,0 @@
import threading
from typing import Any
import httpx
class HttpxPool:
"""Class to manage a global httpx Client instance"""
_clients: dict[str, httpx.Client] = {}
_lock: threading.Lock = threading.Lock()
# Default parameters for creation
DEFAULT_KWARGS = {
"http2": True,
"limits": lambda: httpx.Limits(),
}
def __init__(self) -> None:
pass
@classmethod
def _init_client(cls, **kwargs: Any) -> httpx.Client:
"""Private helper method to create and return an httpx.Client."""
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
return httpx.Client(**merged_kwargs)
@classmethod
def init_client(cls, name: str, **kwargs: Any) -> None:
"""Allow the caller to init the client with extra params."""
with cls._lock:
if name not in cls._clients:
cls._clients[name] = cls._init_client(**kwargs)
@classmethod
def close_client(cls, name: str) -> None:
"""Allow the caller to close the client."""
with cls._lock:
client = cls._clients.pop(name, None)
if client:
client.close()
@classmethod
def close_all(cls) -> None:
"""Close all registered clients."""
with cls._lock:
for client in cls._clients.values():
client.close()
cls._clients.clear()
@classmethod
def get(cls, name: str) -> httpx.Client:
"""Gets the httpx.Client. Will init to default settings if not init'd."""
with cls._lock:
if name not in cls._clients:
cls._clients[name] = cls._init_client()
return cls._clients[name]
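Note: the deleted module above implemented a named, thread-safe registry of long-lived httpx clients. A condensed sketch of that pattern, useful when several components should share one connection pool per name (the removed class defaulted to http2=True, which needs the optional h2 dependency, so it is omitted here):

import threading
from typing import Any

import httpx

class ClientRegistry:
    """Lazily creates and caches one httpx.Client per name, guarded by a lock."""

    _clients: dict[str, httpx.Client] = {}
    _lock = threading.Lock()

    @classmethod
    def get(cls, name: str, **kwargs: Any) -> httpx.Client:
        with cls._lock:
            if name not in cls._clients:
                cls._clients[name] = httpx.Client(**kwargs)
            return cls._clients[name]

    @classmethod
    def close_all(cls) -> None:
        with cls._lock:
            for client in cls._clients.values():
                client.close()
            cls._clients.clear()

vespa_client = ClientRegistry.get("vespa")
assert ClientRegistry.get("vespa") is vespa_client  # the same pooled instance is reused
ClientRegistry.close_all()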

View File

@@ -31,15 +31,14 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.document_index.document_index_utils import (
get_multipass_config,
)
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.indexing_utils import (
get_multipass_config,
)
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -358,6 +357,7 @@ def index_doc_batch(
is_public=False,
)
logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)
ctx = index_doc_batch_prepare(
@@ -380,15 +380,6 @@ def index_doc_batch(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)
doc_descriptors = [
{
"doc_id": doc.id,
"doc_length": doc.get_total_char_length(),
}
for doc in ctx.updatable_docs
]
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@@ -536,8 +527,7 @@ def build_indexing_pipeline(
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)
multipass_config = get_multipass_config(db_session, primary_index=True)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,

View File

@@ -55,7 +55,9 @@ class DocAwareChunk(BaseChunk):
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
return (
f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
)
class IndexChunk(DocAwareChunk):

View File

@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
def get_kv_store() -> KeyValueStore:
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically; it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore()
return PgRedisKVStore(tenant_id=tenant_id)

View File

@@ -18,7 +18,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -28,8 +28,10 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, redis_client: Redis | None = None) -> None:
self.tenant_id = get_current_tenant_id()
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:

View File

@@ -26,7 +26,6 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
@@ -388,7 +387,6 @@ class DefaultMultiLLM(LLM):
try:
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
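Note: the hunk above removes the mock_response pass-through to litellm.completion, which the MOCK_LLM_RESPONSE setting used so tests could run without a real model call. A minimal sketch of how that short-circuit behaves, assuming the litellm package is installed:

import litellm

# With mock_response set, litellm returns a canned ModelResponse without contacting any
# provider, which is what the MOCK_LLM_RESPONSE setting relied on in tests.
response = litellm.completion(
    model="openai/gpt-4o-mini",  # provider/model string; no real API call is made
    messages=[{"role": "user", "content": "ping"}],
    mock_response="pong",
)
print(response.choices[0].message.content)  # -> "pong"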

View File

@@ -109,9 +109,7 @@ from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -214,8 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
get_or_generate_uuid()
get_or_generate_uuid(tenant_id=None)
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:

View File

@@ -1,8 +1,6 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any
@@ -13,7 +11,6 @@ from requests import RequestException
from requests import Response
from retry import retry
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -158,7 +155,6 @@ class EmbeddingModel:
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
@@ -167,14 +163,12 @@ class EmbeddingModel:
)
embeddings: list[Embedding] = []
def process_batch(
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
for idx, text_batch in enumerate(text_batches, start=1):
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -190,52 +184,11 @@ class EmbeddingModel:
api_url=self.api_url,
)
start_time = time.time()
response = self._make_model_server_request(embed_request)
end_time = time.time()
processing_time = end_time - start_time
logger.info(
f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
)
return batch_idx, response.embeddings
# only multi thread if:
# 1. num_threads is greater than 1
# 2. we are using an API-based embedding model (provider_type is not None)
# 3. there is more than 1 batch (no point in threading if only 1)
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}
# Collect results in order
batch_results: list[tuple[int, list[Embedding]]] = []
for future in as_completed(future_to_batch):
try:
result = future.result()
batch_results.append(result)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
except Exception as e:
logger.exception("Embedding model failed to process batch")
raise e
# Sort by batch index and extend embeddings
batch_results.sort(key=lambda x: x[0])
for _, batch_embeddings in batch_results:
embeddings.extend(batch_embeddings)
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
embeddings.extend(response.embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings
def encode(
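Note: the hunk above parallelizes batch embedding only when an API-based provider is configured and more than one batch exists, then sorts the results by batch index so the flattened embedding list matches the input order. A generic sketch of that parallel-map, ordered-collect pattern (embed_batch is a stand-in for the real model-server request):

from concurrent.futures import ThreadPoolExecutor, as_completed

def embed_batch(batch_idx: int, texts: list[str]) -> tuple[int, list[list[float]]]:
    # Stand-in for the real model-server request; returns (batch index, embeddings).
    return batch_idx, [[float(len(text))] for text in texts]

def encode_in_parallel(batches: list[list[str]], num_threads: int = 4) -> list[list[float]]:
    results: list[tuple[int, list[list[float]]]] = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {
            executor.submit(embed_batch, idx, batch): idx
            for idx, batch in enumerate(batches, start=1)
        }
        for future in as_completed(futures):
            results.append(future.result())
    # as_completed yields in completion order, so sort by batch index before flattening.
    results.sort(key=lambda pair: pair[0])
    return [emb for _, batch_embeddings in results for emb in batch_embeddings]

print(encode_in_parallel([["a", "bb"], ["ccc"]]))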

View File

@@ -537,36 +537,30 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
# Let the tag flow handle this case, don't reply twice
return False
# Check if this is a bot message (either via bot_profile or bot_message subtype)
is_bot_message = bool(
event.get("bot_profile") or event.get("subtype") == "bot_message"
)
if is_bot_message:
if event.get("bot_profile"):
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info(
"Ignoring message from bot since respond_to_bots is disabled"
)
channel_specific_logger.info("Ignoring message from bot")
return False
# Ignore things like channel_join, channel_leave, etc.
# NOTE: "file_share" is just a message with a file attachment, so we
# should not ignore it
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share", "bot_message"]:
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)

View File

@@ -17,8 +17,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPermissionSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -43,12 +41,6 @@ class RedisConnectorPermissionSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -62,7 +54,6 @@ class RedisConnectorPermissionSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -116,20 +107,6 @@ class RedisConnectorPermissionSync:
self.redis.set(self.fence_key, payload.model_dump_json())
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed because simply checking the celery queue and
task status could otherwise result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@@ -196,7 +173,6 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -211,9 +187,6 @@ class RedisConnectorPermissionSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
r.delete(key)
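The comments above describe the "active" signal: a short-TTL Redis key that marks the permission-sync workflow as in-flight so cleanup cannot race ahead of queue/status checks. A minimal sketch of that pattern with redis-py (key name and client construction are illustrative only):

import redis

r = redis.Redis()
ACTIVE_KEY = "connectorpermissions_active_42"  # illustrative key
ACTIVE_TTL = 3600

def set_active() -> None:
    # refresh the TTL each time the workflow makes progress
    r.set(ACTIVE_KEY, 0, ex=ACTIVE_TTL)

def is_active() -> bool:
    # if the key has not expired, the workflow is still considered in-flight
    return bool(r.exists(ACTIVE_KEY))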

View File

@@ -11,7 +11,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -136,12 +135,6 @@ class RedisConnectorExternalGroupSync:
) -> int | None:
pass
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"

View File

@@ -33,8 +33,8 @@ class RedisConnectorIndex:
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
# there are gaps in time between states where we need some slack
# to correctly transition
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

View File

@@ -92,7 +92,7 @@ class RedisConnectorPrune:
if fence_bytes is None:
return None
fence_int = int(cast(bytes, fence_bytes))
fence_int = cast(int, fence_bytes)
return fence_int
@generator_complete.setter
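The two fence_int assignments above differ only in whether the bytes returned by Redis are actually parsed. typing.cast is purely a hint for the type checker and performs no conversion at runtime, which a quick snippet makes concrete:

from typing import cast

fence_bytes = b"3"                      # what a redis-py GET returns
assert cast(int, fence_bytes) == b"3"   # cast is a no-op: still bytes at runtime
assert int(fence_bytes) == 3            # int() actually parses the value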

View File

@@ -21,7 +21,6 @@ from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PASSWORD
from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_REPLICA_HOST
from onyx.configs.app_configs import REDIS_SSL
from onyx.configs.app_configs import REDIS_SSL_CA_CERTS
from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
@@ -122,7 +121,7 @@ class TenantRedis(redis.Redis):
"ttl",
] # Regular methods that need simple prefixing
if item == "scan_iter" or item == "sscan_iter":
if item == "scan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
@@ -133,32 +132,23 @@ class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: redis.BlockingConnectionPool
_replica_pool: redis.BlockingConnectionPool
def __new__(cls) -> "RedisPool":
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(RedisPool, cls).__new__(cls)
cls._instance._init_pools()
cls._instance._init_pool()
return cls._instance
def _init_pools(self) -> None:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
def get_replica_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@@ -222,10 +212,6 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_replica_client(tenant_id)
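A hypothetical usage sketch of the primary/replica split that appears in this hunk, assuming the get_redis_client / get_redis_replica_client helpers shown above (they may belong to only one side of the diff): reads can go to the replica pool while mutations stay on the primary.

from onyx.redis.redis_pool import get_redis_client, get_redis_replica_client

def count_taskset(tenant_id: str | None, taskset_key: str) -> int:
    # read-only query served by the replica
    r_read = get_redis_replica_client(tenant_id=tenant_id)
    return int(r_read.scard(taskset_key))

def clear_taskset(tenant_id: str | None, taskset_key: str) -> None:
    # mutation goes to the primary
    r_write = get_redis_client(tenant_id=tenant_id)
    r_write.delete(taskset_key)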
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

View File

@@ -16,7 +16,7 @@ from onyx.context.search.preprocessing.access_filters import (
from onyx.db.document_set import get_document_sets_by_ids
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
@@ -34,11 +34,8 @@ def get_random_chunks_from_doc_sets(
"""
Retrieves random chunks from the specified document sets.
"""
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
acl_filters = build_access_filters_for_user(user, db_session)
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)

View File

@@ -6184,7 +6184,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"title_embedding": [

View File

@@ -24,7 +24,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0

View File

@@ -3,7 +3,6 @@ import json
import os
from typing import cast
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.access.models import default_public_access
@@ -24,7 +23,6 @@ from onyx.db.document import check_docs_exist
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import mock_successful_index_attempt
from onyx.db.models import Document as DbDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import IndexBatchParams
@@ -61,7 +59,6 @@ def _create_indexable_chunks(
doc_updated_at=None,
primary_owners=[],
secondary_owners=[],
chunk_count=1,
)
if preprocessed_doc["chunk_ind"] == 0:
ids_to_documents[document.id] = document
@@ -158,7 +155,9 @@ def seed_initial_documents(
logger.info("Embedding model has been updated, skipping")
return
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Create a connector so the user can delete it if they want
# or reindex it with a new search model if they want
@@ -241,12 +240,4 @@ def seed_initial_documents(
db_session=db_session,
)
# Since we bypass the indexing flow, we need to manually update the chunk count
for doc in docs:
db_session.execute(
update(DbDocument)
.where(DbDocument.id == doc.id)
.values(chunk_count=doc.chunk_count)
)
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)

View File

@@ -15,9 +15,6 @@ from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
try_creating_permissions_sync_task,
)
from onyx.background.celery.tasks.external_group_syncing.tasks import (
try_creating_external_group_sync_task,
)
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
@@ -42,7 +39,7 @@ from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
@@ -192,7 +189,7 @@ def update_cc_pair_status(
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
redis_connector.stop.set_fence(True)
search_settings_list: list[SearchSettings] = get_active_search_settings_list(
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
@@ -422,101 +419,27 @@ def sync_cc_pair(
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Permissions sync task already in progress.",
detail="Doc permissions sync task already in progress.",
)
logger.info(
f"Permissions sync cc_pair={cc_pair_id} "
f"Doc permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Permissions sync task creation failed.",
)
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the permissions sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
def get_cc_pair_latest_group_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_external_group_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
def sync_cc_pair_groups(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers group sync on a particular cc_pair immediately"""
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.external_group_sync.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="External group sync task already in progress.",
)
logger.info(
f"External group sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_external_group_sync_task(
tasks_created = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
detail="Doc permissions sync task creation failed.",
)
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",
message="Successfully created the doc permissions sync task.",
)
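Both sync endpoints above follow the same "fence, then enqueue" shape: refuse to double-submit while a sync is already fenced, and surface task-creation failures as a 500. A simplified, illustrative distillation (the function name, parameters, and return shape are not from the codebase):

from http import HTTPStatus
from fastapi import HTTPException

def trigger_sync(already_fenced: bool, create_task) -> dict:
    if already_fenced:
        # a fence means a sync is in flight; do not enqueue another
        raise HTTPException(
            status_code=HTTPStatus.CONFLICT,
            detail="Sync task already in progress.",
        )
    payload_id = create_task()
    if not payload_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Sync task creation failed.",
        )
    return {"success": True, "message": f"Queued sync task {payload_id}."}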

View File

@@ -32,7 +32,10 @@ def get_document_info(
db_session: Session = Depends(get_session),
) -> DocumentInfo:
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
inference_chunks = document_index.id_based_retrieval(
@@ -76,7 +79,10 @@ def get_chunk_info(
db_session: Session = Depends(get_session),
) -> ChunkInfo:
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
chunk_request = VespaChunkRequest(

View File

@@ -357,7 +357,6 @@ class ConnectorCredentialPairDescriptor(BaseModel):
name: str | None = None
connector: ConnectorSnapshot
credential: CredentialSnapshot
access_type: AccessType
class RunConnectorRequest(BaseModel):

View File

@@ -68,7 +68,6 @@ class DocumentSet(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential
),
access_type=cc_pair.access_type,
)
for cc_pair in document_set_model.connector_credential_pairs
],

View File

@@ -10,7 +10,6 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -196,7 +195,5 @@ def list_llm_provider_basics(
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers_for_user(
db_session, user
)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@@ -44,11 +44,11 @@ class UserPreferences(BaseModel):
chosen_assistants: list[int] | None = None
hidden_assistants: list[int] = []
visible_assistants: list[int] = []
recent_assistants: list[int] | None = None
default_model: str | None = None
auto_scroll: bool | None = None
pinned_assistants: list[int] | None = None
shortcut_enabled: bool | None = None
temperature_override_enabled: bool | None = None
class UserInfo(BaseModel):
@@ -92,7 +92,6 @@ class UserInfo(BaseModel):
hidden_assistants=user.hidden_assistants,
pinned_assistants=user.pinned_assistants,
visible_assistants=user.visible_assistants,
temperature_override_enabled=user.temperature_override_enabled,
)
),
organization_name=organization_name,

View File

@@ -22,7 +22,6 @@ from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
@@ -98,9 +97,10 @@ def set_new_search_settings(
)
# Ensure Vespa has the new index immediately
get_multipass_config(search_settings)
get_multipass_config(new_search_settings)
document_index = get_default_document_index(search_settings, new_search_settings)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=new_search_settings.index_name,
)
document_index.ensure_indices_exist(
index_embedding_dim=search_settings.model_dim,

View File

@@ -568,9 +568,33 @@ def verify_user_logged_in(
"""APIs to adjust user preferences"""
@router.patch("/temperature-override-enabled")
def update_user_temperature_override_enabled(
temperature_override_enabled: bool,
class ChosenDefaultModelRequest(BaseModel):
default_model: str | None = None
class RecentAssistantsRequest(BaseModel):
current_assistant: int
def update_recent_assistants(
recent_assistants: list[int] | None, current_assistant: int
) -> list[int]:
if recent_assistants is None:
recent_assistants = []
else:
recent_assistants = [x for x in recent_assistants if x != current_assistant]
# Add current assistant to start of list
recent_assistants.insert(0, current_assistant)
# Keep only the 5 most recent assistants
recent_assistants = recent_assistants[:5]
return recent_assistants
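A quick worked example of the ordering rule, run against the update_recent_assistants helper shown above: the current assistant moves to the front, any duplicate of it is dropped, and only five entries are kept.

assert update_recent_assistants(None, 9) == [9]
assert update_recent_assistants([3, 7, 9, 12], 9) == [9, 3, 7, 12]
assert update_recent_assistants([1, 2, 3, 4, 5], 6) == [6, 1, 2, 3, 4]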
@router.patch("/user/recent-assistants")
def update_user_recent_assistants(
request: RecentAssistantsRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
@@ -578,26 +602,29 @@ def update_user_temperature_override_enabled(
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.temperature_override_enabled = (
temperature_override_enabled
preferences = no_auth_user.preferences
recent_assistants = preferences.recent_assistants
updated_preferences = update_recent_assistants(
recent_assistants, request.current_assistant
)
set_no_auth_user_preferences(store, no_auth_user.preferences)
preferences.recent_assistants = updated_preferences
set_no_auth_user_preferences(store, preferences)
return
else:
raise RuntimeError("This should never happen")
recent_assistants = UserInfo.from_model(user).preferences.recent_assistants
updated_recent_assistants = update_recent_assistants(
recent_assistants, request.current_assistant
)
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(temperature_override_enabled=temperature_override_enabled)
.values(recent_assistants=updated_recent_assistants)
)
db_session.commit()
class ChosenDefaultModelRequest(BaseModel):
default_model: str | None = None
@router.patch("/shortcut-enabled")
def update_user_shortcut_enabled(
shortcut_enabled: bool,
@@ -704,6 +731,30 @@ class ChosenAssistantsRequest(BaseModel):
chosen_assistants: list[int]
@router.patch("/user/assistant-list")
def update_user_assistant_list(
request: ChosenAssistantsRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(chosen_assistants=request.chosen_assistants)
)
db_session.commit()
def update_assistant_visibility(
preferences: UserPreferences, assistant_id: int, show: bool
) -> UserPreferences:

View File

@@ -14,9 +14,9 @@ from onyx.db.document import get_ingestion_documents
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
@@ -89,10 +89,9 @@ def upsert_ingestion_doc(
)
# Need to index for both the primary and secondary index if possible
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
curr_doc_index = get_default_document_index(
active_search_settings.primary,
None,
primary_index_name=curr_ind_name, secondary_index_name=None
)
search_settings = get_current_search_settings(db_session)
@@ -118,7 +117,11 @@ def upsert_ingestion_doc(
)
# If there's a secondary index being built, index the doc but don't use it for return here
if active_search_settings.secondary:
if sec_ind_name:
sec_doc_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=None
)
sec_search_settings = get_secondary_search_settings(db_session)
if sec_search_settings is None:
@@ -131,10 +134,6 @@ def upsert_ingestion_doc(
search_settings=sec_search_settings
)
sec_doc_index = get_default_document_index(
active_search_settings.secondary, None
)
sec_ind_pipeline = build_indexing_pipeline(
embedder=new_index_embedding_model,
document_index=sec_doc_index,

View File

@@ -18,6 +18,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -77,7 +78,6 @@ from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
@@ -115,52 +115,12 @@ def get_user_chat_sessions(
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]
)
@router.put("/update-chat-session-temperature")
def update_chat_session_temperature(
update_thread_req: UpdateChatSessionTemperatureRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=update_thread_req.chat_session_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
# Validate temperature_override
if update_thread_req.temperature_override is not None:
if (
update_thread_req.temperature_override < 0
or update_thread_req.temperature_override > 2
):
raise HTTPException(
status_code=400, detail="Temperature must be between 0 and 2"
)
# Additional check for Anthropic models
if (
chat_session.current_alternate_model
and "anthropic" in chat_session.current_alternate_model.lower()
):
if update_thread_req.temperature_override > 1:
raise HTTPException(
status_code=400,
detail="Temperature for Anthropic models must be between 0 and 1",
)
chat_session.temperature_override = update_thread_req.temperature_override
db_session.add(chat_session)
db_session.commit()
@router.put("/update-chat-session-model")
def update_chat_session_model(
update_thread_req: UpdateChatSessionThreadRequest,
@@ -231,7 +191,6 @@ def get_chat_session(
],
time_created=chat_session.time_created,
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
)
@@ -463,7 +422,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
@@ -713,25 +672,23 @@ def upload_files_for_chat(
else ChatFileType.PLAIN_TEXT
)
file_content = file.file.read() # Read the file content
if file_type == ChatFileType.IMAGE:
file_content_io = file.file
file_content = file.file
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Avoid issues introduced by file conversion
else:
file_content_io = io.BytesIO(file_content)
file_content = io.BytesIO(file.file.read())
new_content_type = file.content_type
# Store the file normally
# store the file (now JPEG for images)
file_id = str(uuid.uuid4())
file_store.save_file(
file_name=file_id,
content=file_content_io,
content=file_content,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=new_content_type or file_type.value,
@@ -741,7 +698,7 @@ def upload_files_for_chat(
# to re-extract it every time we send a message
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(
file=io.BytesIO(file_content), # use the bytes we already read
file=file.file,
file_name=file.filename or "",
)
text_file_id = str(uuid.uuid4())
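One side of the hunk above reuses the already-read file_content (see the "# use the bytes we already read" comment) instead of calling file.file.read() a second time. A minimal illustration of why a second read of an exhausted stream returns nothing:

import io

stream = io.BytesIO(b"pdf bytes here")
first = stream.read()
second = stream.read()
assert first == b"pdf bytes here"
assert second == b""          # stream exhausted -- the bug the reuse avoids
stream.seek(0)                # alternative fix: rewind before re-reading
assert stream.read() == first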

View File

@@ -42,11 +42,6 @@ class UpdateChatSessionThreadRequest(BaseModel):
new_alternate_model: str
class UpdateChatSessionTemperatureRequest(BaseModel):
chat_session_id: UUID
temperature_override: float
class ChatSessionCreationRequest(BaseModel):
# If not specified, use Onyx default persona
persona_id: int = 0
@@ -113,10 +108,6 @@ class CreateChatMessageRequest(ChunkContext):
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None
# allow user to specify an alternate assistant
alternate_assistant_id: int | None = None
@@ -177,7 +168,6 @@ class ChatSessionDetails(BaseModel):
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
current_temperature_override: float | None = None
class ChatSessionsResponse(BaseModel):
@@ -241,7 +231,6 @@ class ChatSessionDetailResponse(BaseModel):
time_created: datetime
shared_status: ChatSessionSharedStatus
current_alternate_model: str | None
current_temperature_override: float | None
# This one is not used anymore

View File

@@ -64,8 +64,9 @@ def admin_search(
tenant_id=tenant_id,
)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
if not isinstance(document_index, VespaIndex):
raise HTTPException(
status_code=400,

View File

@@ -1,6 +1,4 @@
import base64
import json
import os
from datetime import datetime
from typing import Any
@@ -68,10 +66,3 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
)
return masked_creds
def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is
targeted at the stated use case."""
return base64.b32encode(os.urandom(5)).decode("utf-8")[:8] # 5 bytes → 8 chars
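A quick sanity check of the encoding math in make_short_id: 5 random bytes are 40 bits, which base32 encodes to exactly 8 characters with no padding, so the [:8] slice is effectively a safeguard.

import base64
import os

token = base64.b32encode(os.urandom(5)).decode("utf-8")
assert len(token) == 8 and "=" not in token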

View File

@@ -25,7 +25,6 @@ from onyx.db.llm import fetch_default_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.persona import delete_old_default_personas
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
@@ -37,7 +36,6 @@ from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -72,19 +70,8 @@ def setup_onyx(
The Tenant Service calls the tenants/create endpoint which runs this.
"""
check_index_swap(db_session=db_session)
active_search_settings = get_active_search_settings(db_session)
search_settings = active_search_settings.primary
secondary_search_settings = active_search_settings.secondary
# search_settings = get_current_search_settings(db_session)
# multipass_config_1 = get_multipass_config(search_settings)
# secondary_large_chunks_enabled: bool | None = None
# secondary_search_settings = get_secondary_search_settings(db_session)
# if secondary_search_settings:
# multipass_config_2 = get_multipass_config(secondary_search_settings)
# secondary_large_chunks_enabled = multipass_config_2.enable_large_chunks
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
@@ -135,8 +122,10 @@ def setup_onyx(
# takes a bit of time to start up
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
search_settings,
secondary_search_settings,
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
success = setup_vespa(
@@ -280,7 +269,6 @@ def setup_postgres(db_session: Session) -> None:
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
@@ -294,8 +282,8 @@ def setup_postgres(db_session: Session) -> None:
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session

View File

@@ -220,13 +220,6 @@ class InternetSearchTool(Tool):
)
results = response.json()
# If no hits, Bing does not include the webPages key
search_results = (
results["webPages"]["value"][: self.num_results]
if "webPages" in results
else []
)
return InternetSearchResponse(
revised_query=query,
internet_results=[
@@ -235,7 +228,7 @@ class InternetSearchTool(Tool):
link=result["url"],
snippet=result["snippet"],
)
for result in search_results
for result in results["webPages"]["value"][: self.num_results]
],
)
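An equivalent defensive form of the guard above (not the code in the diff): chained .get() calls avoid a KeyError when Bing omits the webPages block entirely.

def extract_hits(results: dict, num_results: int) -> list[dict]:
    # empty dict / empty list defaults make the lookup total
    return results.get("webPages", {}).get("value", [])[:num_results]

assert extract_hits({}, 3) == []
assert extract_hits({"webPages": {"value": [{"name": "a"}, {"name": "b"}]}}, 1) == [{"name": "a"}]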

View File

@@ -26,13 +26,6 @@ doc_permission_sync_ctx: contextvars.ContextVar[
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
class LoggerContextVars:
@staticmethod
def reset() -> None:
pruning_ctx.set(dict())
doc_permission_sync_ctx.set(dict())
class TaskAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
@@ -77,32 +70,27 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
) -> tuple[str, MutableMapping[str, Any]]:
# If this is an indexing job, add the attempt ID to the log message
# This helps filter the logs for this specific indexing
while True:
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
break
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
if len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
break
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
elif len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
else:
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
if cc_pair_id is not None:
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:

View File

@@ -1,4 +1,3 @@
import contextvars
import threading
import uuid
from enum import Enum
@@ -42,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
def get_or_generate_uuid() -> str:
def get_or_generate_uuid(tenant_id: str | None) -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
@@ -53,7 +52,7 @@ def get_or_generate_uuid() -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID
kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)
try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
@@ -64,18 +63,18 @@ def get_or_generate_uuid() -> str:
return _CACHED_UUID
def _get_or_generate_instance_domain() -> str | None: #
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
global _CACHED_INSTANCE_DOMAIN
if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN
kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -104,7 +103,7 @@ def optional_telemetry(
customer_uuid = (
_get_or_generate_customer_id_mt(tenant_id)
if MULTI_TENANT
else get_or_generate_uuid()
else get_or_generate_uuid(tenant_id)
)
payload = {
"data": data,
@@ -116,23 +115,20 @@ def optional_telemetry(
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain()
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
json=payload,
)
except Exception:
# This way it silences all thread level logging as well
pass
# Run in separate thread with the same context as the current thread
# This is to ensure that the thread gets the current tenant ID
current_context = contextvars.copy_context()
thread = threading.Thread(
target=lambda: current_context.run(telemetry_logic), daemon=True
)
# Run in separate thread to have minimal overhead in main flows
thread = threading.Thread(target=telemetry_logic, daemon=True)
thread.start()
except Exception:
# Should never interfere with normal functions of Onyx
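A minimal sketch of the contextvars pattern referenced in the comments above: a plain thread starts with a fresh context, so the tenant ID has to be carried over explicitly via copy_context() (variable names here are illustrative).

import contextvars
import threading

tenant_var: contextvars.ContextVar[str] = contextvars.ContextVar("tenant", default="public")

def worker() -> None:
    print("tenant seen by thread:", tenant_var.get())

tenant_var.set("tenant_abc")

t1 = threading.Thread(target=worker)
t1.start(); t1.join()
# -> "tenant seen by thread: public"      (new threads get a fresh context)

ctx = contextvars.copy_context()
t2 = threading.Thread(target=lambda: ctx.run(worker))
t2.start(); t2.join()
# -> "tenant seen by thread: tenant_abc"  (the copied context carries the value over)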

View File

@@ -81,7 +81,6 @@ hubspot-api-client==8.1.0
asana==5.0.8
dropbox==11.36.2
boto3-stubs[s3]==1.34.133
shapely==2.0.6
stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4

View File

@@ -197,7 +197,7 @@ ai_platform_doc = SeedPresaveDocument(
)
customer_support_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/support",
url="https://docs.onyx.app/more/use_cases/customer_support",
title=customer_support_title,
content=customer_support,
title_embedding=model.encode(f"search_document: {customer_support_title}"),

View File

@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.search_settings import get_active_search_settings
# Modify sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -39,6 +38,7 @@ from onyx.db.connector_credential_pair import (
from onyx.db.engine import get_session_context_manager
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.file_store import get_default_file_store
from onyx.document_index.document_index_utils import get_both_index_names
# pylint: enable=E402
# flake8: noqa: E402
@@ -191,10 +191,9 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
)
try:
logger.notice("Deleting information from Vespa and Postgres")
active_search_settings = get_active_search_settings(db_session)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
active_search_settings.primary,
active_search_settings.secondary,
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
files_deleted_count = _unsafe_deletion(

View File

@@ -21,144 +21,35 @@ Options:
--doc-id : Document ID
--fields : Fields to update (JSON)
Example:
Example: (gets docs for a given tenant id and connector id)
python vespa_debug_tool.py --action list_docs --tenant-id my_tenant --connector-id 1 --n 5
"""
import argparse
import json
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchRequest
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
class DocumentFilter(BaseModel):
# Document filter for link matching.
link: str | None = None
def build_vespa_filters(
filters: IndexFilters,
*,
include_hidden: bool = False,
remove_trailing_and: bool = False,
) -> str:
# Build a combined Vespa filter string from the given IndexFilters.
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
valid_vals = [val for val in vals if val]
if not key or not valid_vals:
return ""
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
or_clause = " or ".join(eq_elems)
return f"({or_clause})"
def _build_time_filter(
cutoff: datetime | None,
untimed_doc_cutoff: timedelta = timedelta(days=92),
) -> str:
if not cutoff:
return ""
include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
cutoff_secs = int(cutoff.timestamp())
if include_untimed:
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
filter_str = ""
if not include_hidden:
filter_str += f"AND !({HIDDEN}=true) "
if filters.tenant_id and MULTI_TENANT:
filter_str += f'AND ({TENANT_ID} contains "{filters.tenant_id}") '
if filters.access_control_list is not None:
acl_str = _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
if acl_str:
filter_str += f"AND {acl_str} "
source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None
)
source_str = _build_or_filters(SOURCE_TYPE, source_strs)
if source_str:
filter_str += f"AND {source_str} "
tags = filters.tags
if tags:
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
else:
tag_attributes = None
tag_str = _build_or_filters(METADATA_LIST, tag_attributes)
if tag_str:
filter_str += f"AND {tag_str} "
doc_set_str = _build_or_filters(DOCUMENT_SETS, filters.document_set)
if doc_set_str:
filter_str += f"AND {doc_set_str} "
time_filter = _build_time_filter(filters.time_cutoff)
if time_filter:
filter_str += f"AND {time_filter} "
if remove_trailing_and:
while filter_str.endswith(" and "):
filter_str = filter_str[:-5]
while filter_str.endswith("AND "):
filter_str = filter_str[:-4]
return filter_str.strip()
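The nested _build_or_filters helper above, lifted into a standalone sketch so its output is easy to see in isolation (this is not the module's public API; the field name is chosen for the example):

def build_or_clause(key: str, vals: list[str] | None) -> str:
    if vals is None:
        return ""
    valid_vals = [val for val in vals if val]
    if not key or not valid_vals:
        return ""
    # each value becomes a `contains` term; terms are OR-joined and parenthesized
    return "(" + " or ".join(f'{key} contains "{v}"' for v in valid_vals) + ")"

assert build_or_clause("source_type", ["web", "slack"]) == (
    '(source_type contains "web" or source_type contains "slack")'
)
assert build_or_clause("source_type", []) == ""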
# Print Vespa configuration URLs
def print_vespa_config() -> None:
# Print Vespa configuration.
logger.info("Printing Vespa configuration.")
print(f"Vespa Application Endpoint: {VESPA_APPLICATION_ENDPOINT}")
print(f"Vespa App Container URL: {VESPA_APP_CONTAINER_URL}")
print(f"Vespa Search Endpoint: {SEARCH_ENDPOINT}")
print(f"Vespa Document ID Endpoint: {DOCUMENT_ID_ENDPOINT}")
# Check connectivity to Vespa endpoints
def check_vespa_connectivity() -> None:
# Check connectivity to Vespa endpoints.
logger.info("Checking Vespa connectivity.")
endpoints = [
f"{VESPA_APPLICATION_ENDPOINT}/ApplicationStatus",
f"{VESPA_APPLICATION_ENDPOINT}/tenant",
@@ -170,21 +61,17 @@ def check_vespa_connectivity() -> None:
try:
with get_vespa_http_client() as client:
response = client.get(endpoint)
logger.info(
f"Connected to Vespa at {endpoint}, status code {response.status_code}"
)
print(f"Successfully connected to Vespa at {endpoint}")
print(f"Status code: {response.status_code}")
print(f"Response: {response.text[:200]}...")
except Exception as e:
logger.error(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
print(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
print("Vespa connectivity check completed.")
# Get info about the default Vespa application
def get_vespa_info() -> Dict[str, Any]:
# Get info about the default Vespa application.
url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default"
with get_vespa_http_client() as client:
response = client.get(url)
@@ -192,298 +79,121 @@ def get_vespa_info() -> Dict[str, Any]:
return response.json()
def get_index_name(tenant_id: str) -> str:
# Return the index name for a given tenant.
# Get index name for a tenant and connector pair
def get_index_name(tenant_id: str, connector_id: int) -> str:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(db_session, connector_id)
if not cc_pair:
raise ValueError(f"No connector found for id {connector_id}")
search_settings = get_current_search_settings(db_session)
if not search_settings:
raise ValueError(f"No search settings found for tenant {tenant_id}")
return search_settings.index_name
return search_settings.index_name if search_settings else "public"
def query_vespa(
yql: str, tenant_id: Optional[str] = None, limit: int = 10
) -> List[Dict[str, Any]]:
# Perform a Vespa query using YQL syntax.
filters = IndexFilters(tenant_id=tenant_id, access_control_list=[])
filter_string = build_vespa_filters(filters, remove_trailing_and=True)
full_yql = yql.strip()
if filter_string:
full_yql = f"{full_yql} {filter_string}"
full_yql = f"{full_yql} limit {limit}"
params = {"yql": full_yql, "timeout": "10s"}
search_request = SearchRequest(query="", limit=limit, offset=0)
params.update(search_request.model_dump())
logger.info(f"Executing Vespa query: {full_yql}")
# Perform a Vespa query using YQL syntax
def query_vespa(yql: str) -> List[Dict[str, Any]]:
params = {
"yql": yql,
"timeout": "10s",
}
with get_vespa_http_client() as client:
response = client.get(SEARCH_ENDPOINT, params=params)
response.raise_for_status()
result = response.json()
documents = result.get("root", {}).get("children", [])
logger.info(f"Found {len(documents)} documents from query.")
return documents
return response.json()["root"]["children"]
# Get first N documents
def get_first_n_documents(n: int = 10) -> List[Dict[str, Any]]:
# Get the first n documents from any source.
yql = "select * from sources * where true"
return query_vespa(yql, limit=n)
yql = f"select * from sources * where true limit {n};"
return query_vespa(yql)
# Pretty-print a list of documents
def print_documents(documents: List[Dict[str, Any]]) -> None:
# Pretty-print a list of documents.
for doc in documents:
print(json.dumps(doc, indent=2))
print("-" * 80)
# Get and print documents for a specific tenant and connector
def get_documents_for_tenant_connector(
tenant_id: str, connector_id: int, n: int = 10
) -> None:
# Get and print documents for a specific tenant and connector.
index_name = get_index_name(tenant_id)
logger.info(
f"Fetching documents for tenant={tenant_id}, connector_id={connector_id}"
)
yql = f"select * from sources {index_name} where true"
documents = query_vespa(yql, tenant_id, limit=n)
print(
f"First {len(documents)} documents for tenant {tenant_id}, connector {connector_id}:"
)
get_index_name(tenant_id, connector_id)
documents = get_first_n_documents(n)
print(f"First {n} documents for tenant {tenant_id}, connector {connector_id}:")
print_documents(documents)
# Search documents for a specific tenant and connector
def search_documents(
tenant_id: str, connector_id: int, query: str, n: int = 10
) -> None:
# Search documents for a specific tenant and connector.
index_name = get_index_name(tenant_id)
logger.info(
f"Searching documents for tenant={tenant_id}, connector_id={connector_id}, query='{query}'"
)
yql = f"select * from sources {index_name} where userInput(@query)"
documents = query_vespa(yql, tenant_id, limit=n)
print(f"Search results for query '{query}' in tenant {tenant_id}:")
index_name = get_index_name(tenant_id, connector_id)
yql = f"select * from sources {index_name} where userInput(@query) limit {n};"
documents = query_vespa(yql)
print(f"Search results for query '{query}':")
print_documents(documents)
# Update a specific document
def update_document(
tenant_id: str, connector_id: int, doc_id: str, fields: Dict[str, Any]
) -> None:
# Update a specific document.
index_name = get_index_name(tenant_id)
logger.info(
f"Updating document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
)
index_name = get_index_name(tenant_id, connector_id)
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
update_request = {"fields": {k: {"assign": v} for k, v in fields.items()}}
with get_vespa_http_client() as client:
response = client.put(url, json=update_request)
response.raise_for_status()
logger.info(f"Document {doc_id} updated successfully.")
print(f"Document {doc_id} updated successfully")
# Delete a specific document
def delete_document(tenant_id: str, connector_id: int, doc_id: str) -> None:
# Delete a specific document.
index_name = get_index_name(tenant_id)
logger.info(
f"Deleting document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
)
index_name = get_index_name(tenant_id, connector_id)
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
with get_vespa_http_client() as client:
response = client.delete(url)
response.raise_for_status()
logger.info(f"Document {doc_id} deleted successfully.")
print(f"Document {doc_id} deleted successfully")
def list_documents(n: int = 10, tenant_id: Optional[str] = None) -> None:
# List documents from any source, filtered by tenant if provided.
logger.info(f"Listing up to {n} documents for tenant={tenant_id or 'ALL'}")
yql = "select * from sources * where true"
if tenant_id:
yql += f" and tenant_id contains '{tenant_id}'"
documents = query_vespa(yql, tenant_id=tenant_id, limit=n)
print(f"Total documents found: {len(documents)}")
logger.info(f"Total documents found: {len(documents)}")
print(f"First {min(n, len(documents))} documents:")
for doc in documents[:n]:
print(json.dumps(doc, indent=2))
# List documents from any source
def list_documents(n: int = 10) -> None:
yql = f"select * from sources * where true limit {n};"
url = f"{VESPA_APP_CONTAINER_URL}/search/"
params = {
"yql": yql,
"timeout": "10s",
}
try:
with get_vespa_http_client() as client:
response = client.get(url, params=params)
response.raise_for_status()
documents = response.json()["root"]["children"]
print(f"First {n} documents:")
print_documents(documents)
except Exception as e:
print(f"Failed to list documents: {str(e)}")
# Get and print ACLs for documents of a specific tenant and connector
def get_document_acls(tenant_id: str, connector_id: int, n: int = 10) -> None:
index_name = get_index_name(tenant_id, connector_id)
yql = f"select documentid, access_control_list from sources {index_name} where true limit {n};"
documents = query_vespa(yql)
print(f"ACLs for {n} documents from tenant {tenant_id}, connector {connector_id}:")
for doc in documents:
print(f"Document ID: {doc['fields']['documentid']}")
print(
f"ACL: {json.dumps(doc['fields'].get('access_control_list', {}), indent=2)}"
)
print("-" * 80)
def get_document_and_chunk_counts(
tenant_id: str, cc_pair_id: int, filter_doc: DocumentFilter | None = None
) -> Dict[str, int]:
# Return a dict mapping each document ID to its chunk count for a given connector.
with get_session_with_tenant(tenant_id=tenant_id) as session:
doc_ids_data = (
session.query(DocumentByConnectorCredentialPair.id, Document.link)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(Document, DocumentByConnectorCredentialPair.id == Document.id)
.filter(ConnectorCredentialPair.id == cc_pair_id)
.distinct()
.all()
)
doc_ids = []
for doc_id, link in doc_ids_data:
if filter_doc and filter_doc.link:
if link and filter_doc.link.lower() in link.lower():
doc_ids.append(doc_id)
else:
doc_ids.append(doc_id)
chunk_counts_data = (
session.query(Document.id, Document.chunk_count)
.filter(Document.id.in_(doc_ids))
.all()
)
return {
doc_id: chunk_count
for doc_id, chunk_count in chunk_counts_data
if chunk_count is not None
}
def get_chunk_ids_for_connector(
tenant_id: str,
cc_pair_id: int,
index_name: str,
filter_doc: DocumentFilter | None = None,
) -> List[UUID]:
# Return chunk IDs for a given connector.
doc_id_to_new_chunk_cnt = get_document_and_chunk_counts(
tenant_id, cc_pair_id, filter_doc
)
doc_infos: List[EnrichedDocumentIndexingInfo] = [
VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=get_vespa_http_client(),
document_id=doc_id,
previous_chunk_count=doc_id_to_new_chunk_cnt.get(doc_id, 0),
new_chunk_count=0,
)
for doc_id in doc_id_to_new_chunk_cnt.keys()
]
chunk_ids = get_document_chunk_ids(
enriched_document_info_list=doc_infos,
tenant_id=tenant_id,
large_chunks_enabled=False,
)
if not isinstance(chunk_ids, list):
raise ValueError(f"Expected list of chunk IDs, got {type(chunk_ids)}")
return chunk_ids
def get_document_acls(
tenant_id: str,
cc_pair_id: int,
n: int | None = 10,
filter_doc: DocumentFilter | None = None,
) -> None:
# Fetch document ACLs for the given tenant and connector pair.
index_name = get_index_name(tenant_id)
logger.info(
f"Fetching document ACLs for tenant={tenant_id}, cc_pair_id={cc_pair_id}"
)
chunk_ids: List[UUID] = get_chunk_ids_for_connector(
tenant_id, cc_pair_id, index_name, filter_doc
)
vespa_client = get_vespa_http_client()
target_ids = chunk_ids if n is None else chunk_ids[:n]
logger.info(
f"Found {len(chunk_ids)} chunk IDs, showing ACLs for {len(target_ids)}."
)
for doc_chunk_id in target_ids:
document_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{str(doc_chunk_id)}"
)
response = vespa_client.get(document_url)
if response.status_code == 200:
fields = response.json().get("fields", {})
document_id = fields.get("document_id") or fields.get(
"documentid", "Unknown"
)
acls = fields.get("access_control_list", {})
title = fields.get("title", "")
source_type = fields.get("source_type", "")
source_links_raw = fields.get("source_links", "{}")
try:
source_links = json.loads(source_links_raw)
except json.JSONDecodeError:
source_links = {}
print(f"Document Chunk ID: {doc_chunk_id}")
print(f"Document ID: {document_id}")
print(f"ACLs:\n{json.dumps(acls, indent=2)}")
print(f"Source Links: {source_links}")
print(f"Title: {title}")
print(f"Source Type: {source_type}")
if MULTI_TENANT:
print(f"Tenant ID: {fields.get('tenant_id', 'N/A')}")
print("-" * 80)
else:
logger.error(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
print(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
print(f"Status Code: {response.status_code}")
print("-" * 80)


class VespaDebugging:
    # Class for managing Vespa debugging actions.
    def __init__(self, tenant_id: str | None = None):
        self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id

    def print_config(self) -> None:
        # Print Vespa config.
        print_vespa_config()

    def check_connectivity(self) -> None:
        # Check Vespa connectivity.
        check_vespa_connectivity()

    def list_documents(self, n: int = 10) -> None:
        # List documents for a tenant.
        list_documents(n, self.tenant_id)

    def search_documents(self, connector_id: int, query: str, n: int = 10) -> None:
        # Search documents for a tenant and connector.
        search_documents(self.tenant_id, connector_id, query, n)

    def update_document(
        self, connector_id: int, doc_id: str, fields: Dict[str, Any]
    ) -> None:
        # Update a document.
        update_document(self.tenant_id, connector_id, doc_id, fields)

    def delete_document(self, connector_id: int, doc_id: str) -> None:
        # Delete a document.
        delete_document(self.tenant_id, connector_id, doc_id)

    def acls_by_link(self, cc_pair_id: int, link: str) -> None:
        # Get ACLs for a document matching a link.
        get_document_acls(
            self.tenant_id, cc_pair_id, n=None, filter_doc=DocumentFilter(link=link)
        )

    def acls(self, cc_pair_id: int, n: int | None = 10) -> None:
        # Get ACLs for a connector.
        get_document_acls(self.tenant_id, cc_pair_id, n)
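
# A minimal sketch of using VespaDebugging from a Python shell, assuming the
# module's imports are available and Vespa is reachable; the tenant ID and
# cc_pair ID below are placeholders:
#
#     debugger = VespaDebugging(tenant_id="public")
#     debugger.check_connectivity()
#     debugger.list_documents(n=5)
#     debugger.acls(cc_pair_id=1, n=5)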


def main() -> None:
# Main CLI entry point.
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(
"--action",
        choices=[
            "config",
            "connect",
            "list_docs",
            "search",
            "update",
            "delete",
            "get_acls",
        ],
        required=True,
        help="Action to perform",
    )
    parser.add_argument("--tenant-id", help="Tenant ID")
    parser.add_argument("--connector-id", type=int, help="Connector ID")
    parser.add_argument(
        "--n", type=int, default=10, help="Number of documents to retrieve"
    )
    parser.add_argument("--query", help="Search query (for search action)")
    parser.add_argument("--doc-id", help="Document ID (for update and delete actions)")
    parser.add_argument(
        "--fields", help="Fields to update, in JSON format (for update)"
    )

    args = parser.parse_args()
    vespa_debug = VespaDebugging(args.tenant_id)

    if args.action == "config":
        vespa_debug.print_config()
    elif args.action == "connect":
        vespa_debug.check_connectivity()
    elif args.action == "list_docs":
        vespa_debug.list_documents(args.n)
    elif args.action == "search":
        if not args.query or args.connector_id is None:
            parser.error("--query and --connector-id are required for search action")
        vespa_debug.search_documents(args.connector_id, args.query, args.n)
    elif args.action == "update":
        if not args.doc_id or not args.fields or args.connector_id is None:
            parser.error(
                "--doc-id, --fields, and --connector-id are required for update action"
            )
        fields = json.loads(args.fields)
        vespa_debug.update_document(args.connector_id, args.doc_id, fields)
    elif args.action == "delete":
        if not args.doc_id or args.connector_id is None:
            parser.error("--doc-id and --connector-id are required for delete action")
        vespa_debug.delete_document(args.connector_id, args.doc_id)
    elif args.action == "get_acls":
        if args.connector_id is None:
            parser.error("--connector-id is required for get_acls action")
        vespa_debug.acls(args.connector_id, args.n)
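
# Example invocations, assuming this file is run as a standalone script
# (the filename below is a placeholder):
#
#     python vespa_debug_tool.py --action connect
#     python vespa_debug_tool.py --action list_docs --tenant-id public --n 5
#     python vespa_debug_tool.py --action get_acls --tenant-id public --connector-id 1 --n 5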


if __name__ == "__main__":
    main()