Compare commits

...

1 Commits

Author SHA1 Message Date
justin-tahara
e6dc885422 fix(docprocessing): Local Threading 2025-12-19 18:38:35 -08:00
226 changed files with 2430 additions and 17228 deletions

View File

@@ -33,11 +33,6 @@ env:
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
jobs:
discover-test-dirs:
@@ -404,11 +399,6 @@ jobs:
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -8,66 +8,30 @@ repos:
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
- id: uv-sync
args: ["--active", "--locked", "--all-extras"]
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"backend",
"-o",
"backend/requirements/default.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export dev.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"dev",
"-o",
"backend/requirements/dev.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export ee.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"ee",
"-o",
"backend/requirements/ee.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export model_server.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"model_server",
"-o",
"backend/requirements/model_server.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -76,73 +40,68 @@ repos:
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-yaml
files: ^.github/
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
- id: actionlint
- repo: https://github.com/psf/black
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
hooks:
- id: black
language_version: python3.11
- id: black
language_version: python3.11
# this is a fork which keeps compatibility with black
- repo: https://github.com/wimglenn/reorder-python-imports-black
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
hooks:
- id: reorder-python-imports
args: ["--py311-plus", "--application-directories=backend/"]
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
- id: reorder-python-imports
args: ['--py311-plus', '--application-directories=backend/']
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
# These settings will remove unused imports with side effects
# Note: The repo currently does not and should not have imports with side effects
- repo: https://github.com/PyCQA/autoflake
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
hooks:
- id: autoflake
args:
[
"--remove-all-unused-imports",
"--remove-unused-variables",
"--in-place",
"--recursive",
]
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
- repo: https://github.com/golangci/golangci-lint
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-prettier
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- repo: https://github.com/sirwart/ripsecrets
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
hooks:
- id: ripsecrets
args:
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- repo: local
hooks:

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None
depends_on = None
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:

View File

@@ -42,13 +42,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
)
@@ -63,13 +63,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.ForeignKeyConstraint(

View File

@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
print(f"File {file_id} not found in PostgreSQL storage.")
continue
lobj_id = cast(int, file_record.lobj_oid)
file_metadata = cast(Any, file_record.file_metadata)
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
# Read file content from PostgreSQL
try:
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
else:
# Convert other types to dict if possible, otherwise None
try:
file_metadata = dict(file_record.file_metadata)
file_metadata = dict(file_record.file_metadata) # type: ignore
except (TypeError, ValueError):
file_metadata = None

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None
depends_on = None
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:

View File

@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import (
from onyx.db.enums import ( # type: ignore[import-untyped]
MCPTransport,
MCPAuthenticationType,
MCPAuthenticationPerformer,

View File

@@ -82,9 +82,9 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore[arg-type]
target_metadata=target_metadata, # type: ignore
include_object=include_object,
)
) # type: ignore
with context.begin_transaction():
context.run_migrations()

View File

@@ -118,6 +118,6 @@ def fetch_document_sets(
.all()
)
document_set_with_cc_pairs.append((document_set, cc_pairs))
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
return document_set_with_cc_pairs

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
@@ -36,8 +36,8 @@ from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore[import-untyped]
from transformers import PreTrainedTokenizer, BatchEncoding
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
logger = setup_logger()

View File

@@ -42,7 +42,7 @@ def get_embedding_model(
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
"""
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer # type: ignore
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
"""
@@ -91,7 +91,7 @@ def get_local_reranking_model(
model_name: str,
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder
from sentence_transformers import CrossEncoder # type: ignore
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
@@ -195,7 +195,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)

View File

@@ -12,7 +12,7 @@ from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
if TYPE_CHECKING:
from transformers import DistilBertConfig
from transformers import DistilBertConfig # type: ignore
class HybridClassifier(nn.Module):
@@ -34,7 +34,7 @@ class HybridClassifier(nn.Module):
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
@@ -102,7 +102,7 @@ class ConnectorClassifier(nn.Module):
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert(
hidden_states = self.distilbert( # type: ignore
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state

View File

@@ -43,7 +43,7 @@ def get_access_for_document(
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
@@ -93,7 +93,9 @@ def get_access_for_documents(
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(document_ids, db_session)
return versioned_get_access_for_documents_fn(
document_ids, db_session
) # type: ignore
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
@@ -111,7 +113,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
versioned_acl_for_user_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session)
return versioned_acl_for_user_fn(user, db_session) # type: ignore
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:

View File

@@ -338,7 +338,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request)
user = await super().create(
user_create, safe=safe, request=request
) # type: ignore
user_created = True
except IntegrityError as error:
# Race condition: another request created the same user after the
@@ -602,7 +604,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
remove_user_from_invited_users(user.email)
@@ -1173,7 +1178,7 @@ async def _sync_jwt_oidc_expiry(
return
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
user.oidc_expiry = oidc_expiry
user.oidc_expiry = oidc_expiry # type: ignore
return
if user.oidc_expiry is not None:

View File

@@ -1,135 +0,0 @@
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
index_attempt_id = None
try:
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# Use higher priority for first-time indexing to ensure new connectors
# get processed before re-indexing of existing connectors
has_successful_attempt = cc_pair.last_successful_index_time is not None
priority = (
OnyxCeleryPriority.MEDIUM
if has_successful_attempt
else OnyxCeleryPriority.HIGH
)
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=queue,
task_id=custom_task_id,
priority=priority,
)
if not result:
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id

View File

@@ -25,14 +25,14 @@ from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
try_creating_docfetching_task,
)
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state
from onyx.background.celery.tasks.docprocessing.utils import should_index
from onyx.background.celery.tasks.docprocessing.utils import (
try_creating_docfetching_task,
)
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
from onyx.background.indexing.checkpointing_utils import (

View File

@@ -1,15 +1,22 @@
import time
from datetime import datetime
from datetime import timezone
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -17,6 +24,8 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -289,3 +298,112 @@ def should_index(
return False
return True
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
index_attempt_id = None
try:
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=queue,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id

View File

@@ -368,19 +368,11 @@ def connector_document_extraction(
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
# Use higher priority for first-time indexing to ensure new connectors
# get processed before re-indexing of existing connectors
docprocessing_priority = (
OnyxCeleryPriority.MEDIUM
if has_successful_attempt
else OnyxCeleryPriority.HIGH
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
@@ -503,7 +495,6 @@ def connector_document_extraction(
tenant_id,
app,
most_recent_attempt,
docprocessing_priority,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
@@ -616,7 +607,7 @@ def connector_document_extraction(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=docprocessing_priority,
priority=OnyxCeleryPriority.MEDIUM,
)
batch_num += 1
@@ -767,7 +758,6 @@ def reissue_old_batches(
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
priority: OnyxCeleryPriority,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
@@ -795,7 +785,7 @@ def reissue_old_batches(
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=priority,
priority=OnyxCeleryPriority.MEDIUM,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more

View File

@@ -6,7 +6,6 @@ from typing import Any
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
@@ -87,7 +86,7 @@ class ChatStateContainer:
return self.is_clarification
def run_chat_loop_with_state_containers(
def run_chat_llm_with_state_containers(
func: Callable[..., None],
is_connected: Callable[[], bool],
emitter: Emitter,
@@ -111,7 +110,7 @@ def run_chat_loop_with_state_containers(
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_loop_with_state_containers(
packets = run_chat_llm_with_state_containers(
my_func,
emitter=emitter,
state_container=state_container,
@@ -132,7 +131,7 @@ def run_chat_loop_with_state_containers(
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
turn_index=0,
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -27,12 +27,11 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
@@ -444,7 +443,7 @@ def run_llm_loop(
tool_definitions=[tool.tool_definition() for tool in final_tools],
tool_choice=tool_choice,
llm=llm,
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
turn_index=llm_cycle_count + reasoning_cycles,
citation_processor=citation_processor,
state_container=state_container,
# The rich docs representation is passed in so that when yielding the answer, it can also
@@ -496,7 +495,7 @@ def run_llm_loop(
raise ValueError("Tool response missing tool_call reference")
tool_call = tool_response.tool_call
tab_index = tool_call.placement.tab_index
tab_index = tool_call.tab_index
# Track if search tool was called (for skipping query expansion on subsequent calls)
if tool_call.tool_name == SearchTool.NAME:
@@ -626,7 +625,7 @@ def run_llm_loop(
emitter.emit(
Packet(
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
turn_index=llm_cycle_count + reasoning_cycles,
obj=OverallStop(type="stop"),
)
)

View File

@@ -38,7 +38,6 @@ from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDone
from onyx.server.query_and_chat.streaming_models import ReasoningStart
@@ -199,7 +198,6 @@ def _update_tool_call_with_delta(
def _extract_tool_call_kickoffs(
id_to_tool_call_map: dict[int, dict[str, Any]],
turn_index: int,
sub_turn_index: int | None = None,
) -> list[ToolCallKickoff]:
"""Extract ToolCallKickoff objects from the tool call map.
@@ -224,11 +222,8 @@ def _extract_tool_call_kickoffs(
tool_call_id=tool_call_data["id"],
tool_name=tool_call_data["name"],
tool_args=tool_args,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index,
tab_index=tab_index,
)
)
tab_index += 1
@@ -380,21 +375,12 @@ def translate_history_to_llm_format(
return messages
def _increment_turns(
turn_index: int, sub_turn_index: int | None
) -> tuple[int, int | None]:
if sub_turn_index is None:
return turn_index + 1, None
else:
return turn_index, sub_turn_index + 1
def run_llm_step_pkt_generator(
history: list[ChatMessageSimple],
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
placement: Placement,
turn_index: int,
state_container: ChatStateContainer,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
@@ -403,58 +389,9 @@ def run_llm_step_pkt_generator(
custom_token_processor: (
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
) = None,
) -> Generator[Packet, None, tuple[LlmStepResult, bool]]:
"""Run an LLM step and stream the response as packets.
NOTE: DO NOT TOUCH THIS FUNCTION BEFORE ASKING YUHONG, this is very finicky and
delicate logic that is core to the app's main functionality.
This generator function streams LLM responses, processing reasoning content,
answer content, tool calls, and citations. It yields Packet objects for
real-time streaming to clients and accumulates the final result.
Args:
history: List of chat messages in the conversation history.
tool_definitions: List of tool definitions available to the LLM.
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
llm: Language model interface to use for generation.
turn_index: Current turn index in the conversation.
state_container: Container for storing chat state (reasoning, answers).
citation_processor: Optional processor for extracting and formatting citations
from the response. If provided, processes tokens to identify citations.
reasoning_effort: Optional reasoning effort configuration for models that
support reasoning (e.g., o1 models).
final_documents: Optional list of search documents to include in the response
start packet.
user_identity: Optional user identity information for the LLM.
custom_token_processor: Optional callable that processes each token delta
before yielding. Receives (delta, processor_state) and returns
(modified_delta, new_processor_state). Can return None for delta to skip.
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
Yields:
Packet: Streaming packets containing:
- ReasoningStart/ReasoningDelta/ReasoningDone for reasoning content
- AgentResponseStart/AgentResponseDelta for answer content
- CitationInfo for extracted citations
- ToolCallKickoff for tool calls (extracted at the end)
Returns:
tuple[LlmStepResult, bool]: A tuple containing:
- LlmStepResult: The final result with accumulated reasoning, answer,
and tool calls (if any).
- bool: Whether reasoning occurred during this step. This should be used to
increment the turn index or sub_turn index for the rest of the LLM loop.
Note:
The function handles incremental state updates, saving reasoning and answer
tokens to the state container as they are generated. Tool calls are extracted
and yielded only after the stream completes.
"""
turn_index = placement.turn_index
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
# The second return value is for the turn index because reasoning counts on the frontend as a turn
# TODO this is maybe ok but does not align well with the backend logic too well
llm_msg_history = translate_history_to_llm_format(history)
has_reasoned = 0
@@ -519,19 +456,11 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index,
obj=ReasoningStart(),
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index,
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
@@ -539,26 +468,15 @@ def run_llm_step_pkt_generator(
if delta.content:
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=AgentResponseStart(
final_documents=final_documents,
),
@@ -572,20 +490,12 @@ def run_llm_step_pkt_generator(
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=result,
)
else:
@@ -594,28 +504,17 @@ def run_llm_step_pkt_generator(
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=AgentResponseDelta(content=delta.content),
)
if delta.tool_calls:
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
for tool_call_delta in delta.tool_calls:
@@ -629,7 +528,7 @@ def run_llm_step_pkt_generator(
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
tool_calls = _extract_tool_call_kickoffs(
id_to_tool_call_map, turn_index, sub_turn_index
id_to_tool_call_map, turn_index + has_reasoned
)
if tool_calls:
tool_calls_list: list[ToolCall] = [
@@ -657,16 +556,15 @@ def run_llm_step_pkt_generator(
tool_calls=None,
)
span_generation.span_data.output = [assistant_msg_no_tools.model_dump()]
# Should have closed the reasoning block, the only pathway to hit this is if the stream
# ended with reasoning content and no other content. This is an invalid state.
# Close reasoning block if still open (stream ended with reasoning content)
if reasoning_start:
raise RuntimeError("Reasoning block is still open but the stream ended.")
yield Packet(
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
for result in citation_processor.process_token(None):
if isinstance(result, str):
@@ -674,20 +572,12 @@ def run_llm_step_pkt_generator(
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
turn_index=turn_index + has_reasoned,
obj=result,
)
@@ -722,7 +612,7 @@ def run_llm_step(
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
placement: Placement,
turn_index: int,
state_container: ChatStateContainer,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
@@ -742,7 +632,7 @@ def run_llm_step(
tool_definitions=tool_definitions,
tool_choice=tool_choice,
llm=llm,
placement=placement,
turn_index=turn_index,
state_container=state_container,
citation_processor=citation_processor,
reasoning_effort=reasoning_effort,

View File

@@ -8,7 +8,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
from onyx.chat.chat_state import run_chat_llm_with_state_containers
from onyx.chat.chat_utils import convert_chat_history
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import get_custom_agent_prompt
@@ -65,7 +65,7 @@ from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.constants import SEARCH_TOOL_ID
from onyx.tools.interface import Tool
from onyx.tools.tool import Tool
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import CustomToolConfig
from onyx.tools.tool_constructor import SearchToolConfig
@@ -553,7 +553,7 @@ def stream_chat_message_objects(
# (user has already responded to a clarification question)
skip_clarification = is_last_assistant_message_clarification(chat_history)
yield from run_chat_loop_with_state_containers(
yield from run_chat_llm_with_state_containers(
run_deep_research_llm_loop,
is_connected=check_is_connected,
emitter=emitter,
@@ -568,7 +568,7 @@ def stream_chat_message_objects(
user_identity=user_identity,
)
else:
yield from run_chat_loop_with_state_containers(
yield from run_chat_llm_with_state_containers(
run_llm_loop,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,

View File

@@ -22,7 +22,7 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.tools.interface import Tool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)

View File

@@ -583,16 +583,6 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
# Slack federated search thread context settings
# Batch size for fetching thread context (controls concurrent API calls per batch)
SLACK_THREAD_CONTEXT_BATCH_SIZE = int(
os.environ.get("SLACK_THREAD_CONTEXT_BATCH_SIZE", "5")
)
# Maximum messages to fetch thread context for (top N by relevance get full context)
MAX_SLACK_THREAD_CONTEXT_MESSAGES = int(
os.environ.get("MAX_SLACK_THREAD_CONTEXT_MESSAGES", "5")
)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -708,15 +698,6 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
# The intent was to have this be configurable per query, but I don't think any
# codepath was actually configuring this, so for the migrated Vespa interface
# we'll just use the default value, but also have it be configurable by env var.
RECENCY_BIAS_MULTIPLIER = float(os.environ.get("RECENCY_BIAS_MULTIPLIER") or 1.0)
# Should match the rerank-count value set in
# backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja.
RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
#####
# Tool Configs

View File

@@ -563,7 +563,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
class OnyxCallTypes(str, Enum):

View File

@@ -38,7 +38,7 @@ class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None
self._user = None # type: ignore
self.workspace_gid = workspace_gid
self.team_gid = team_gid

View File

@@ -9,14 +9,14 @@ from typing import Any
from typing import Optional
from urllib.parse import quote
import boto3
from botocore.client import Config
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from botocore.session import get_session
from mypy_boto3_s3 import S3Client
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE

View File

@@ -2,11 +2,11 @@ from datetime import timezone
from io import BytesIO
from typing import Any
from dropbox import Dropbox # type: ignore[import-untyped]
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
from dropbox.exceptions import AuthError
from dropbox.files import FileMetadata # type: ignore[import-untyped]
from dropbox.files import FolderMetadata
from dropbox import Dropbox # type: ignore
from dropbox.exceptions import ApiError # type:ignore
from dropbox.exceptions import AuthError # type:ignore
from dropbox.files import FileMetadata # type:ignore
from dropbox.files import FolderMetadata # type:ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource

View File

@@ -5,8 +5,8 @@ from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.access.models import ExternalAccess

View File

@@ -14,9 +14,9 @@ from typing import cast
from typing import Protocol
from urllib.parse import urlparse
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
@@ -1006,7 +1006,7 @@ class GoogleDriveConnector(
file.user_email,
)
if file.error is None:
file.error = exc
file.error = exc # type: ignore[assignment]
yield file
continue

View File

@@ -1,9 +1,9 @@
import json
from typing import Any
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

View File

@@ -4,7 +4,7 @@ from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
@@ -179,7 +179,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
)
) # type: ignore
return str(auth_url)

View File

@@ -1,11 +1,11 @@
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.discovery import build # type: ignore[import-untyped]
from googleapiclient.discovery import Resource
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from onyx.utils.logger import setup_logger

View File

@@ -10,7 +10,7 @@ from urllib.parse import urlparse
from urllib.parse import urlunparse
from pywikibot import family # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]

View File

@@ -10,8 +10,8 @@ from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot import textlib
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot import textlib # type: ignore[import-untyped]
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource

View File

@@ -4,9 +4,9 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.teams.channels.channel import Channel # type: ignore[import-untyped]
from office365.teams.channels.channel import ConversationMember
from office365.graph_client import GraphClient # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.channels.channel import ConversationMember # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -21,13 +21,6 @@ class OptionalSearchSetting(str, Enum):
class QueryType(str, Enum):
"""
The type of first-pass query to use for hybrid search.
The values of this enum are injected into the ranking profile name which
should match the name in the schema.
"""
KEYWORD = "keyword"
SEMANTIC = "semantic"

View File

@@ -13,8 +13,6 @@ from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_SLACK_THREAD_CONTEXT_MESSAGES
from onyx.configs.app_configs import SLACK_THREAD_CONTEXT_BATCH_SIZE
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
@@ -625,55 +623,33 @@ def merge_slack_messages(
return merged_messages, docid_to_message, all_filtered_channels
class SlackRateLimitError(Exception):
"""Raised when Slack API returns a rate limit error (429)."""
class ThreadContextResult:
"""Result wrapper for thread context fetch that captures error type."""
__slots__ = ("text", "is_rate_limited", "is_error")
def __init__(
self, text: str, is_rate_limited: bool = False, is_error: bool = False
):
self.text = text
self.is_rate_limited = is_rate_limited
self.is_error = is_error
@classmethod
def success(cls, text: str) -> "ThreadContextResult":
return cls(text)
@classmethod
def rate_limited(cls, original_text: str) -> "ThreadContextResult":
return cls(original_text, is_rate_limited=True)
@classmethod
def error(cls, original_text: str) -> "ThreadContextResult":
return cls(original_text, is_error=True)
def _fetch_thread_context(
def get_contextualized_thread_text(
message: SlackMessage, access_token: str, team_id: str | None = None
) -> ThreadContextResult:
) -> str:
"""
Fetch thread context for a message, returning a result object.
Retrieves the initial thread message as well as the text following the message
and combines them into a single string. If the slack query fails, returns the
original message text.
Returns ThreadContextResult with:
- success: enriched thread text
- rate_limited: original text + flag indicating we should stop
- error: original text for other failures (graceful degradation)
The idea is that the message (the one that actually matched the search), the
initial thread message, and the replies to the message are important in answering
the user's query.
Args:
message: The SlackMessage to get context for
access_token: Slack OAuth access token
team_id: Slack team ID for caching user profiles (optional but recommended)
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
# if it's not a thread, return the message text
if thread_id is None:
return ThreadContextResult.success(message.text)
return message.text
slack_client = WebClient(token=access_token, timeout=30)
# get the thread messages
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
@@ -682,44 +658,19 @@ def _fetch_thread_context(
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
# Check for rate limit error specifically
if e.response and e.response.status_code == 429:
logger.warning(
f"Slack rate limit hit while fetching thread context for {channel_id}/{thread_id}"
)
return ThreadContextResult.rate_limited(message.text)
# For other Slack errors, log and return original text
logger.error(f"Slack API error in thread context fetch: {e}")
return ThreadContextResult.error(message.text)
except Exception as e:
# Network errors, timeouts, etc - treat as recoverable error
logger.error(f"Unexpected error in thread context fetch: {e}")
return ThreadContextResult.error(message.text)
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
return message.text
# If empty response or single message (not a thread), return original text
# make sure we didn't get an empty response or a single message (not a thread)
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
return message.text
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build the thread text from messages."""
# add the initial thread message
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# add the message (unless it's the initial message)
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
@@ -730,21 +681,28 @@ def _build_thread_text(
if not message_id_idx:
return thread_text
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
# Include a few messages BEFORE the matched message for context
# This helps understand what the matched message is responding to
start_idx = max(
1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW
) # Start after thread starter
# Add ellipsis if we're skipping messages between thread starter and context window
if start_idx > 1:
thread_text += "\n..."
# Add context messages before the matched message
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add the matched message itself
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
# add the following replies to the thread text
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
@@ -752,19 +710,22 @@ def _build_thread_text(
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# Replace user IDs with names using cached lookups
# replace user ids with names in the thread text using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
if team_id:
# Use cached batch lookup when team_id is available
user_profiles = batch_get_user_profiles(access_token, team_id, userids)
for userid, name in user_profiles.items():
thread_text = thread_text.replace(f"<@{userid}>", name)
else:
# Fallback to individual lookups (no caching) when team_id not available
for userid in userids:
try:
response = slack_client.users_profile_get(user=userid)
@@ -774,7 +735,7 @@ def _build_thread_text(
except SlackApiError as e:
if "user_not_found" in str(e):
logger.debug(
f"User {userid} not found (likely deleted/deactivated)"
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
)
else:
logger.warning(f"Could not fetch profile for user {userid}: {e}")
@@ -786,115 +747,6 @@ def _build_thread_text(
return thread_text
def fetch_thread_contexts_with_rate_limit_handling(
slack_messages: list[SlackMessage],
access_token: str,
team_id: str | None,
batch_size: int = SLACK_THREAD_CONTEXT_BATCH_SIZE,
max_messages: int | None = MAX_SLACK_THREAD_CONTEXT_MESSAGES,
) -> list[str]:
"""
Fetch thread contexts in controlled batches, stopping on rate limit.
Distinguishes between error types:
- Rate limit (429): Stop processing further batches
- Other errors: Continue processing (graceful degradation)
Args:
slack_messages: Messages to fetch thread context for (should be sorted by relevance)
access_token: Slack OAuth token
team_id: Slack team ID for user profile caching
batch_size: Number of concurrent API calls per batch
max_messages: Maximum messages to fetch thread context for (None = no limit)
Returns:
List of thread texts, one per input message.
Messages beyond max_messages or after rate limit get their original text.
"""
if not slack_messages:
return []
# Limit how many messages we fetch thread context for (if max_messages is set)
if max_messages and max_messages < len(slack_messages):
messages_for_context = slack_messages[:max_messages]
messages_without_context = slack_messages[max_messages:]
else:
messages_for_context = slack_messages
messages_without_context = []
logger.info(
f"Fetching thread context for {len(messages_for_context)} of {len(slack_messages)} messages "
f"(batch_size={batch_size}, max={max_messages or 'unlimited'})"
)
results: list[str] = []
rate_limited = False
total_batches = (len(messages_for_context) + batch_size - 1) // batch_size
rate_limit_batch = 0
# Process in batches
for i in range(0, len(messages_for_context), batch_size):
current_batch = i // batch_size + 1
if rate_limited:
# Skip remaining batches, use original message text
remaining = messages_for_context[i:]
skipped_batches = total_batches - rate_limit_batch
logger.warning(
f"Slack rate limit: skipping {len(remaining)} remaining messages "
f"({skipped_batches} of {total_batches} batches). "
f"Successfully enriched {len(results)} messages before rate limit."
)
results.extend([msg.text for msg in remaining])
break
batch = messages_for_context[i : i + batch_size]
# _fetch_thread_context returns ThreadContextResult (never raises)
# allow_failures=True is a safety net for any unexpected exceptions
batch_results: list[ThreadContextResult | None] = (
run_functions_tuples_in_parallel(
[
(
_fetch_thread_context,
(msg, access_token, team_id),
)
for msg in batch
],
allow_failures=True,
max_workers=batch_size,
)
)
# Process results - ThreadContextResult tells us exactly what happened
for j, result in enumerate(batch_results):
if result is None:
# Unexpected exception (shouldn't happen) - use original text, stop
logger.error(f"Unexpected None result for message {j} in batch")
results.append(batch[j].text)
rate_limited = True
rate_limit_batch = current_batch
elif result.is_rate_limited:
# Rate limit hit - use original text, stop further batches
results.append(result.text)
rate_limited = True
rate_limit_batch = current_batch
else:
# Success or recoverable error - use the text (enriched or original)
results.append(result.text)
if rate_limited:
logger.warning(
f"Slack rate limit (429) hit at batch {current_batch}/{total_batches} "
f"while fetching thread context. Stopping further API calls."
)
# Add original text for messages we didn't fetch context for
results.extend([msg.text for msg in messages_without_context])
return results
def convert_slack_score(slack_score: float) -> float:
"""
Convert slack score to a score between 0 and 1.
@@ -1112,12 +964,11 @@ def slack_retrieval(
if not slack_messages:
return []
# Fetch thread context with rate limit handling and message limiting
# Messages are already sorted by relevance (slack_score), so top N get full context
thread_texts = fetch_thread_contexts_with_rate_limit_handling(
slack_messages=slack_messages,
access_token=access_token,
team_id=team_id,
thread_texts: list[str] = run_functions_tuples_in_parallel(
[
(get_contextualized_thread_text, (slack_message, access_token, team_id))
for slack_message in slack_messages
]
)
for slack_message, thread_text in zip(slack_messages, thread_texts):
slack_message.text = thread_text

View File

@@ -90,16 +90,6 @@ def _build_index_filters(
if not source_filter and detected_source_filter:
source_filter = detected_source_filter
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
# source type is included in the filter, otherwise user files will be excluded!
if user_file_ids and source_filter:
from onyx.configs.constants import DocumentSource
# Add user_file to the source filter if not already present
if DocumentSource.USER_FILE not in source_filter:
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
@@ -114,7 +104,6 @@ def _build_index_filters(
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
)
return final_filters

View File

@@ -44,7 +44,6 @@ def query_analysis(query: str) -> tuple[bool, list[str]]:
return analysis_model.predict(query)
# TODO: This is unused code.
@log_function_time(print_only=True)
def retrieval_preprocessing(
search_request: SearchRequest,

View File

@@ -118,7 +118,6 @@ def combine_retrieval_results(
return sorted_chunks
# TODO: This is unused code.
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -349,7 +348,6 @@ def retrieve_chunks(
list(query.filters.source_type) if query.filters.source_type else None,
query.filters.document_set,
slack_context,
query.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -477,7 +475,6 @@ def search_chunks(
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
slack_context=slack_context,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(

View File

@@ -63,7 +63,7 @@ def get_live_users_count(db_session: Session) -> int:
This does NOT include invited users, "users" pulled in
from external connectors, or API keys.
"""
count_stmt = func.count(User.id)
count_stmt = func.count(User.id) # type: ignore
select_stmt = select(count_stmt)
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
user_count = db_session.scalar(select_stmt_w_filters)
@@ -74,7 +74,7 @@ def get_live_users_count(db_session: Session) -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_context_manager() as session:
count_stmt = func.count(User.id)
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
user_count = await session.scalar(stmt_w_filters)
@@ -100,10 +100,10 @@ class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def get_user_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount)
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
async def get_access_token_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
yield SQLAlchemyAccessTokenDatabase(session, AccessToken)
yield SQLAlchemyAccessTokenDatabase(session, AccessToken) # type: ignore

View File

@@ -40,21 +40,6 @@ def check_connectors_exist(db_session: Session) -> bool:
return result.scalar() or False
def check_user_files_exist(db_session: Session) -> bool:
"""Check if any user files exist in the system.
This is used to determine if the search tool should be available
when there are no regular connectors but there are user files
(User Knowledge mode).
"""
from onyx.db.models import UserFile
from onyx.db.enums import UserFileStatus
stmt = select(exists(UserFile).where(UserFile.status == UserFileStatus.COMPLETED))
result = db_session.execute(stmt)
return result.scalar() or False
def fetch_connectors(
db_session: Session,
sources: list[DocumentSource] | None = None,

View File

@@ -290,7 +290,7 @@ def get_document_counts_for_cc_pairs(
)
)
for connector_id, credential_id, cnt in db_session.execute(stmt).all():
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
aggregated_counts[(connector_id, credential_id)] = cnt
# Convert aggregated results back to the expected sequence of tuples
@@ -1098,7 +1098,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
def update_document_kg_stages(
@@ -1121,7 +1121,7 @@ def update_document_kg_stages(
result = db_session.execute(stmt)
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
def get_skipped_kg_documents(db_session: Session) -> list[str]:

View File

@@ -2959,7 +2959,7 @@ class SlackChannelConfig(Base):
"slack_bot_id",
"is_default",
unique=True,
postgresql_where=(is_default is True),
postgresql_where=(is_default is True), # type: ignore
),
)

View File

@@ -257,7 +257,7 @@ def _get_users_by_emails(
"""given a list of lowercase emails,
returns a list[User] of Users whose emails match and a list[str]
the missing emails that had no User"""
stmt = select(User).filter(func.lower(User.email).in_(lower_emails))
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
# Extract found emails and convert to lowercase to avoid case sensitivity issues

View File

@@ -1,7 +1,6 @@
# TODO: Notes for potential extensions and future improvements:
# 1. Allow tools that aren't search specific tools
# 2. Use user provided custom prompts
# 3. Save the plan for replay
from collections.abc import Callable
from typing import cast
@@ -38,7 +37,6 @@ from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_R
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
@@ -47,9 +45,9 @@ from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.fake_tools.research_agent import run_research_agent_calls
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -110,7 +108,7 @@ def generate_final_report(
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
placement=Placement(turn_index=999), # TODO
turn_index=999, # TODO
citation_processor=None,
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
@@ -187,7 +185,7 @@ def run_deep_research_llm_loop(
tool_definitions=get_clarification_tool_definitions(),
tool_choice=ToolChoiceOptions.AUTO,
llm=llm,
placement=Placement(turn_index=0),
turn_index=0,
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=None,
@@ -200,9 +198,7 @@ def run_deep_research_llm_loop(
# Mark this turn as a clarification question
state_container.set_is_clarification(True)
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=OverallStop(type="stop"))
)
emitter.emit(Packet(turn_index=0, obj=OverallStop(type="stop")))
# If a clarification is asked, we need to end this turn and wait on user input
return
@@ -233,7 +229,7 @@ def run_deep_research_llm_loop(
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
placement=Placement(turn_index=0),
turn_index=0,
citation_processor=None,
state_container=state_container,
final_documents=None,
@@ -248,14 +244,14 @@ def run_deep_research_llm_loop(
if isinstance(packet.obj, AgentResponseStart):
emitter.emit(
Packet(
placement=packet.placement,
turn_index=packet.turn_index,
obj=DeepResearchPlanStart(),
)
)
elif isinstance(packet.obj, AgentResponseDelta):
emitter.emit(
Packet(
placement=packet.placement,
turn_index=packet.turn_index,
obj=DeepResearchPlanDelta(content=packet.obj.content),
)
)
@@ -269,7 +265,7 @@ def run_deep_research_llm_loop(
emitter.emit(
Packet(
# Marks the last turn end which should be the plan generation
placement=Placement(turn_index=orchestrator_start_turn_index - 1),
turn_index=orchestrator_start_turn_index - 1,
obj=SectionEnd(),
)
)
@@ -344,9 +340,7 @@ def run_deep_research_llm_loop(
),
tool_choice=ToolChoiceOptions.REQUIRED,
llm=llm,
placement=Placement(
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
),
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=DynamicCitationProcessor(),
@@ -366,8 +360,6 @@ def run_deep_research_llm_loop(
)
if not tool_calls:
# Basically hope that this is an infrequent occurence and hopefully multiple research
# cycles have already ran
logger.warning("No tool calls found, this should not happen.")
generate_final_report(
history=simple_chat_history,
@@ -398,8 +390,6 @@ def run_deep_research_llm_loop(
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
# we will show it as a separate message.
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
# the LLM step to handle this
most_recent_reasoning = state_container.reasoning_tokens
tool_call_message = think_tool_call.to_msg_str()
@@ -420,6 +410,7 @@ def run_deep_research_llm_loop(
image_files=None,
)
simple_chat_history.append(think_tool_response_msg)
reasoning_cycles += 1
continue
else:
for tool_call in tool_calls:
@@ -444,7 +435,6 @@ def run_deep_research_llm_loop(
break
research_results = run_research_agent_calls(
# The tool calls here contain the placement information
research_agent_calls=research_agent_calls,
tools=allowed_tools,
emitter=emitter,

View File

@@ -90,8 +90,6 @@ class MetadataUpdateRequest(BaseModel):
the contents of the document.
"""
model_config = {"frozen": True}
document_ids: list[str]
# Passed in to help with potential optimizations of the implementation. The
# keys should be redundant with document_ids.

View File

@@ -273,7 +273,6 @@ schema {{ schema_name }} {
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
# Target hits for hybrid retrieval should be at least this value.
rerank-count: 1000
}
@@ -342,7 +341,6 @@ schema {{ schema_name }} {
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
# Target hits for hybrid retrieval should be at least this value.
rerank-count: 1000
}

View File

@@ -15,19 +15,19 @@ from typing import cast
from typing import List
from uuid import UUID
import httpx
import httpx # type: ignore
import jinja2
import requests
import requests # type: ignore
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
@@ -88,6 +88,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs
@@ -957,37 +958,48 @@ class VespaIndex(DocumentIndex):
offset: int = 0,
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
) -> list[InferenceChunk]:
tenant_id = filters.tenant_id if filters.tenant_id is not None else ""
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=TenantState(
tenant_id=tenant_id,
multitenant=self.multitenant,
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
+ f"(({{targetHits: {target_hits}}}nearestNeighbor(embeddings, query_embedding)) "
+ f"or ({{targetHits: {target_hits}}}nearestNeighbor(title_embedding, query_embedding)) "
+ 'or ({grammar: "weakAnd"}userInput(@query)) '
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
final_query = " ".join(final_keywords) if final_keywords else query
if ranking_profile_type == QueryExpansionType.KEYWORD:
ranking_profile = f"hybrid_search_keyword_base_{len(query_embedding)}"
else:
ranking_profile = f"hybrid_search_semantic_base_{len(query_embedding)}"
logger.info(f"Selected ranking profile: {ranking_profile}")
logger.debug(f"Query YQL: {yql}")
params: dict[str, str | int | float] = {
"yql": yql,
"query": final_query,
"input.query(query_embedding)": str(query_embedding),
"input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier),
"input.query(alpha)": hybrid_alpha,
"input.query(title_content_ratio)": (
title_content_ratio
if title_content_ratio is not None
else TITLE_CONTENT_RATIO
),
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
if not (
ranking_profile_type == QueryExpansionType.KEYWORD
or ranking_profile_type == QueryExpansionType.SEMANTIC
):
raise ValueError(
f"Bug: Received invalid ranking profile type: {ranking_profile_type}"
)
query_type = (
QueryType.KEYWORD
if ranking_profile_type == QueryExpansionType.KEYWORD
else QueryType.SEMANTIC
)
return vespa_document_index.hybrid_retrieval(
query,
query_embedding,
final_keywords,
query_type,
filters,
num_to_retrieve,
offset,
)
"hits": num_to_retrieve,
"offset": offset,
"ranking.profile": ranking_profile,
"timeout": VESPA_TIMEOUT,
}
return cleanup_chunks(query_vespa(params))
def admin_retrieval(
self,

View File

@@ -1,22 +1,11 @@
import concurrent.futures
import logging
from uuid import UUID
import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
@@ -26,7 +15,6 @@ from onyx.document_index.interfaces_new import DocumentInsertionRecord
from onyx.document_index.interfaces_new import DocumentSectionRequest
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
@@ -35,32 +23,13 @@ from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.tools.tool_implementations.search.constants import KEYWORD_QUERY_HYBRID_ALPHA
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs from httpx. By
# default it emits INFO-level logs for every request.
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
class TenantState(BaseModel):
"""
Captures the tenant-related state for an instance of VespaDocumentIndex.
@@ -133,211 +102,6 @@ def _enrich_basic_chunk_info(
return enriched_doc_info
def _cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
"""Removes indexing-time content additions from chunks retrieved from Vespa.
During indexing, chunks are augmented with additional text to improve search
quality:
- Title prepended to content (for better keyword/semantic matching)
- Metadata suffix appended to content
- Contextual RAG: doc_summary (beginning) and chunk_context (end)
This function strips these additions before returning chunks to users,
restoring the original document content. Cleaning is applied in sequence:
1. Title removal:
- Full match: Strips exact title from beginning
- Partial match: If content starts with title[:BLURB_SIZE], splits on
RETURN_SEPARATOR to remove title section
2. Metadata suffix removal:
- Strips metadata_suffix from end, plus trailing RETURN_SEPARATOR
3. Contextual RAG removal:
- Strips doc_summary from beginning (if present)
- Strips chunk_context from end (if present)
Args:
chunks: Chunks as retrieved from Vespa with indexing augmentations
intact.
Returns:
Clean InferenceChunk objects with augmentations removed, containing only
the original document content that should be shown to users.
"""
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
if not chunk.title or not chunk.content:
return chunk.content
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()
# BLURB SIZE is by token instead of char but each token is at least 1 char
# If this prefix matches the content, it's assumed the title was prepended
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
return (
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
if RETURN_SEPARATOR in chunk.content
else chunk.content
)
return chunk.content
def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
if not chunk.metadata_suffix:
return chunk.content
return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
RETURN_SEPARATOR
)
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
# remove document summary
if chunk.doc_summary and chunk.content.startswith(chunk.doc_summary):
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
# remove chunk context
if chunk.chunk_context and chunk.content.endswith(chunk.chunk_context):
chunk.content = chunk.content[
: len(chunk.content) - len(chunk.chunk_context)
].rstrip()
return chunk.content
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
chunk.content = _remove_contextual_rag(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]
@retry(
tries=3,
delay=1,
backoff=2,
exceptions=httpx.HTTPError,
)
def _update_single_chunk(
doc_chunk_id: UUID,
index_name: str,
doc_id: str,
http_client: httpx.Client,
update_request: MetadataUpdateRequest,
) -> None:
"""Updates a single document chunk in Vespa.
TODO(andrei): Couldn't this be batched?
Args:
doc_chunk_id: The ID of the chunk to update.
index_name: The index the chunk belongs to.
doc_id: The ID of the document the chunk belongs to.
http_client: The HTTP client to use to make the request.
update_request: Metadata update request object received in the bulk
update method containing fields to update.
"""
class _Boost(BaseModel):
model_config = {"frozen": True}
assign: float
class _DocumentSets(BaseModel):
model_config = {"frozen": True}
assign: dict[str, int]
class _AccessControl(BaseModel):
model_config = {"frozen": True}
assign: dict[str, int]
class _Hidden(BaseModel):
model_config = {"frozen": True}
assign: bool
class _UserProjects(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based the Vespa schema. Changes to the
# schema require changes here. These names were originally found in
# backend/onyx/document_index/vespa_constants.py.
boost: _Boost | None = None
document_sets: _DocumentSets | None = None
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
fields: _VespaPutFields
boost_update: _Boost | None = (
_Boost(assign=update_request.boost)
if update_request.boost is not None
else None
)
document_sets_update: _DocumentSets | None = (
_DocumentSets(
assign={document_set: 1 for document_set in update_request.document_sets}
)
if update_request.document_sets is not None
else None
)
access_update: _AccessControl | None = (
_AccessControl(
assign={acl_entry: 1 for acl_entry in update_request.access.to_acl()}
)
if update_request.access is not None
else None
)
hidden_update: _Hidden | None = (
_Hidden(assign=update_request.hidden)
if update_request.hidden is not None
else None
)
user_projects_update: _UserProjects | None = (
_UserProjects(assign=list(update_request.project_ids))
if update_request.project_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
document_sets=document_sets_update,
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
)
vespa_put_request = _VespaPutRequest(
fields=vespa_put_fields,
)
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
"?create=true"
)
try:
resp = http_client.put(
vespa_url,
headers={"Content-Type": "application/json"},
json=vespa_put_request.model_dump(
exclude_none=True
), # NOTE: Important to not produce null fields in the json.
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). "
f"Code: {e.response.status_code}. Details: {e.response.text}"
)
# Re-raise so the @retry decorator will catch and retry, unless the
# status code is < 5xx, in which case wrap the exception in something
# other than an HTTPError to skip retries.
if e.response.status_code >= 500:
raise
raise RuntimeError(
f"Non-retryable error updating chunk {doc_chunk_id}: {e}"
) from e
class VespaDocumentIndex(DocumentIndex):
"""Vespa-specific implementation of the DocumentIndex interface.
@@ -371,10 +135,6 @@ class VespaDocumentIndex(DocumentIndex):
get_vespa_http_client
)
self._multitenant = tenant_state.multitenant
if self._multitenant:
assert (
self._tenant_id
), "Bug: Must supply a tenant id if in multitenant mode."
def verify_and_create_index_if_necessary(
self, embedding_dim: int, embedding_precision: EmbeddingPrecision
@@ -499,35 +259,7 @@ class VespaDocumentIndex(DocumentIndex):
raise NotImplementedError
def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
with self._httpx_client_context as httpx_client:
# Each invocation of this method can contain multiple update requests.
for update_request in update_requests:
# Each update request can correspond to multiple documents.
for doc_id in update_request.document_ids:
chunk_count = update_request.doc_id_to_chunk_cnt[doc_id]
sanitized_doc_id = replace_invalid_doc_id_characters(doc_id)
enriched_doc_info = _enrich_basic_chunk_info(
index_name=self._index_name,
http_client=httpx_client,
document_id=sanitized_doc_id,
previous_chunk_count=chunk_count,
new_chunk_count=0, # WARNING: This semantically makes no sense and is misusing this function.
)
doc_chunk_ids = get_document_chunk_ids(
enriched_document_info_list=[enriched_doc_info],
tenant_id=self._tenant_id,
large_chunks_enabled=self._large_chunks_enabled,
)
for doc_chunk_id in doc_chunk_ids:
_update_single_chunk(
doc_chunk_id,
self._index_name,
doc_id,
httpx_client,
update_request,
)
raise NotImplementedError
def id_based_retrieval(
self, chunk_requests: list[DocumentSectionRequest]
@@ -544,56 +276,7 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)
+ vespa_where_clauses
+ f"(({{targetHits: {target_hits}}}nearestNeighbor(embeddings, query_embedding)) "
+ f"or ({{targetHits: {target_hits}}}nearestNeighbor(title_embedding, query_embedding)) "
+ 'or ({grammar: "weakAnd"}userInput(@query)) '
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
final_query = " ".join(final_keywords) if final_keywords else query
ranking_profile = (
f"hybrid_search_{query_type.value}_base_{len(query_embedding)}"
)
logger.info(f"Selected ranking profile: {ranking_profile}")
logger.debug(f"Query YQL: {yql}")
# In this interface we do not pass in hybrid alpha. Tracing the codepath
# of the legacy Vespa interface, it so happens that KEYWORD always
# corresponds to an alpha of 0.2 (from KEYWORD_QUERY_HYBRID_ALPHA), and
# SEMANTIC to 0.5 (from HYBRID_ALPHA). HYBRID_ALPHA_KEYWORD was only
# used in dead code so we do not use it here.
hybrid_alpha = (
KEYWORD_QUERY_HYBRID_ALPHA
if query_type == QueryType.KEYWORD
else HYBRID_ALPHA
)
params: dict[str, str | int | float] = {
"yql": yql,
"query": final_query,
"input.query(query_embedding)": str(query_embedding),
"input.query(decay_factor)": str(DOC_TIME_DECAY * RECENCY_BIAS_MULTIPLIER),
"input.query(alpha)": hybrid_alpha,
"input.query(title_content_ratio)": TITLE_CONTENT_RATIO,
"hits": num_to_retrieve,
"offset": offset,
"ranking.profile": ranking_profile,
"timeout": VESPA_TIMEOUT,
}
return _cleanup_chunks(query_vespa(params))
raise NotImplementedError
def random_retrieval(
self,

View File

@@ -34,7 +34,7 @@ class BraintrustEvalProvider(EvalProvider):
eval_data = [EvalCase(input=item["input"]) for item in data]
Eval(
name=BRAINTRUST_PROJECT,
data=eval_data,
data=eval_data, # type: ignore[arg-type]
task=task,
scores=[],
metadata={**configuration.model_dump()},

View File

@@ -38,18 +38,7 @@ def get_federated_retrieval_functions(
source_types: list[DocumentSource] | None,
document_set_names: list[str] | None,
slack_context: SlackContext | None = None,
user_file_ids: list[UUID] | None = None,
) -> list[FederatedRetrievalInfo]:
# When User Knowledge (user files) is the only knowledge source enabled,
# skip federated connectors entirely. User Knowledge mode means the agent
# should ONLY use uploaded files, not team connectors like Slack.
if user_file_ids and not document_set_names:
logger.debug(
"Skipping all federated connectors: User Knowledge mode enabled "
f"with {len(user_file_ids)} user files and no document sets"
)
return []
# Check for Slack bot context first (regardless of user_id)
if slack_context:
logger.debug("Slack context detected, checking for Slack bot setup...")

View File

@@ -47,10 +47,10 @@ class SlackEntities(BaseModel):
# Message count per slack request
max_messages_per_query: int = Field(
default=10,
default=25,
description=(
"Maximum number of messages to retrieve per search query. "
"Higher values increase API calls and may trigger rate limits."
"Higher values provide more context but may be slower."
),
)

View File

@@ -36,7 +36,7 @@ def delete_unstructured_api_key() -> None:
def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> "operations.PartitionRequest":
from unstructured_client.models import operations
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared
file.seek(0, 0)
@@ -62,7 +62,7 @@ def unstructured_to_text(file: IO[Any], file_name: str) -> str:
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
response = unstructured_client.general.partition(req)
response = unstructured_client.general.partition(req) # type: ignore
elements = dict_to_elements(response.elements)
if response.status_code != 200:

View File

@@ -50,10 +50,10 @@ from onyx.indexing.models import IndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.multi_llm import LLMRateLimitError
from onyx.llm.utils import llm_response_to_string
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.search_nlp_models import (

View File

@@ -45,7 +45,9 @@ class PgRedisKVStore(KeyValueStore):
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
db_session.query(KVStore).filter_by(key=key).delete() # just in case
db_session.add(obj)
db_session.commit()
@@ -95,7 +97,7 @@ class PgRedisKVStore(KeyValueStore):
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with get_session_with_current_tenant() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete()
result = db_session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
db_session.commit()

View File

@@ -14,12 +14,12 @@ from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.chat_llm import LitellmLLM
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPENROUTER_PROVIDER_NAME
from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input

View File

@@ -152,7 +152,7 @@ def litellm_exception_to_error_msg(
if message_attr:
upstream_detail = str(message_attr)
elif hasattr(core_exception, "api_error"):
api_error = core_exception.api_error
api_error = core_exception.api_error # type: ignore[attr-defined]
if isinstance(api_error, dict):
upstream_detail = (
api_error.get("message")

View File

@@ -15,10 +15,10 @@ from typing import cast
import aioboto3 # type: ignore
import httpx
import requests
import voyageai # type: ignore[import-untyped]
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from cohere.core.api_error import ApiError
from google.oauth2 import service_account
from google.oauth2 import service_account # type: ignore
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
@@ -89,6 +89,30 @@ _AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
# Thread-local storage for event loops
# This prevents creating thousands of event loops during batch processing,
# which was causing severe memory leaks with API-based embedding providers
_thread_local = threading.local()
def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
"""Get or create a thread-local event loop for API embedding calls.
This prevents creating a new event loop for every batch during embedding,
which was causing memory leaks. Instead, each thread reuses the same loop.
Returns:
asyncio.AbstractEventLoop: The thread-local event loop
"""
if (
not hasattr(_thread_local, "loop")
or _thread_local.loop is None
or _thread_local.loop.is_closed()
):
_thread_local.loop = asyncio.new_event_loop()
asyncio.set_event_loop(_thread_local.loop)
return _thread_local.loop
WARM_UP_STRINGS = [
"Onyx is amazing!",
@@ -271,8 +295,8 @@ class CloudEmbedding:
embedding_type: str,
reduced_dimension: int | None,
) -> list[Embedding]:
from google import genai
from google.genai import types as genai_types
from google import genai # type: ignore[import-untyped]
from google.genai import types as genai_types # type: ignore[import-untyped]
if not model:
model = DEFAULT_VERTEX_MODEL
@@ -776,16 +800,14 @@ class EmbeddingModel:
# Route between direct API calls and model server calls
if self.provider_type is not None:
# For API providers, make direct API call
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
response = loop.run_until_complete(
self._make_direct_api_call(
embed_request, tenant_id=tenant_id, request_id=request_id
)
# Use thread-local event loop to prevent memory leaks from creating
# thousands of event loops during batch processing
loop = _get_or_create_event_loop()
response = loop.run_until_complete(
self._make_direct_api_call(
embed_request, tenant_id=tenant_id, request_id=request_id
)
finally:
loop.close()
)
else:
# For local models, use model server
response = self._make_model_server_request(

View File

@@ -3,8 +3,8 @@ from abc import ABC
from abc import abstractmethod
from copy import copy
from tokenizers import Encoding # type: ignore[import-untyped]
from tokenizers import Tokenizer
from tokenizers import Encoding # type: ignore
from tokenizers import Tokenizer # type: ignore
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.context.search.models import InferenceChunk

View File

@@ -1,5 +1,5 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
from mistune import Markdown # type: ignore
from mistune import Renderer # type: ignore
def format_slack_message(message: str | None) -> str:

View File

@@ -452,7 +452,7 @@ def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
ttl = r.ttl(name)
locked = lock.locked()
owned = lock.owned()
local_token: str | None = lock.local.token
local_token: str | None = lock.local.token # type: ignore
remote_token_raw = r.get(lock.name)
if remote_token_raw:

View File

@@ -16,7 +16,7 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from google.oauth2.credentials import Credentials
from google.oauth2.credentials import Credentials # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

View File

@@ -1,10 +0,0 @@
from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None

View File

@@ -11,7 +11,6 @@ from onyx.db.chat import get_db_search_doc_by_id
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
from onyx.db.models import ChatMessage
from onyx.db.tools import get_tool_by_id
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -60,7 +59,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=AgentResponseStart(
final_documents=final_search_docs,
),
@@ -69,7 +68,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=AgentResponseDelta(
content=message_text,
),
@@ -78,7 +77,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=SectionEnd(),
)
)
@@ -95,12 +94,12 @@ def create_citation_packets(
for citation_info in citation_info_list:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=citation_info,
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -108,20 +107,18 @@ def create_citation_packets(
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
)
packets.append(Packet(turn_index=turn_index, obj=ReasoningStart()))
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=ReasoningDelta(
reasoning=reasoning_text,
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -133,24 +130,21 @@ def create_image_generation_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationToolStart(),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationFinal(images=images),
),
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=SectionEnd(),
)
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
return packets
@@ -167,14 +161,16 @@ def create_custom_tool_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolStart(tool_name=tool_name),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=tool_name,
response_type=response_type,
@@ -184,12 +180,7 @@ def create_custom_tool_packets(
),
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=SectionEnd(),
)
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
return packets
@@ -204,32 +195,30 @@ def create_fetch_packets(
# Emit start packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlStart(),
)
)
# Emit URLs packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlUrls(urls=urls),
)
)
# Emit documents packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlDocuments(
documents=[SearchDoc(**doc.model_dump()) for doc in fetch_docs]
),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=SectionEnd(),
)
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
return packets
@@ -244,7 +233,8 @@ def create_search_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolStart(
is_internet_search=is_internet_search,
),
@@ -255,7 +245,8 @@ def create_search_packets(
if search_queries:
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolQueriesDelta(queries=search_queries),
),
)
@@ -267,7 +258,8 @@ def create_search_packets(
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolDocumentsDelta(
documents=[
SearchDoc(**doc.model_dump()) for doc in sorted_search_docs
@@ -276,12 +268,7 @@ def create_search_packets(
),
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
obj=SectionEnd(),
)
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
return packets
@@ -450,8 +437,6 @@ def translate_assistant_message_to_packets(
)
# Add overall stop packet at the end
packet_list.append(
Packet(placement=Placement(turn_index=final_turn_index), obj=OverallStop())
)
packet_list.append(Packet(turn_index=final_turn_index, obj=OverallStop()))
return packet_list

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
class StreamingType(Enum):
@@ -289,7 +288,20 @@ PacketObj = Union[
]
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int
class Packet(BaseModel):
placement: Placement
turn_index: int | None
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
obj: Annotated[PacketObj, Field(discriminator="type")]

View File

@@ -4,7 +4,6 @@ import re
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -63,7 +62,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=AgentResponseStart(
final_documents=SearchDoc.from_saved_search_docs(final_documents or []),
),
@@ -84,7 +83,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=AgentResponseDelta(
content=adjusted_message_text,
),
@@ -93,7 +92,7 @@ def create_message_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=SectionEnd(),
)
)
@@ -110,12 +109,12 @@ def create_citation_packets(
for citation_info in citation_info_list:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=citation_info,
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -123,20 +122,18 @@ def create_citation_packets(
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
)
packets.append(Packet(turn_index=turn_index, obj=ReasoningStart()))
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=ReasoningDelta(
reasoning=reasoning_text,
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -148,19 +145,19 @@ def create_image_generation_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=ImageGenerationToolStart(),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=ImageGenerationFinal(images=images),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -176,14 +173,14 @@ def create_custom_tool_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=CustomToolStart(tool_name=tool_name),
)
)
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=CustomToolDelta(
tool_name=tool_name,
response_type=response_type,
@@ -193,7 +190,7 @@ def create_custom_tool_packets(
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -207,27 +204,27 @@ def create_fetch_packets(
# Emit start packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=OpenUrlStart(),
)
)
# Emit URLs packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=OpenUrlUrls(urls=urls),
)
)
# Emit documents packet
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=OpenUrlDocuments(
documents=SearchDoc.from_saved_search_docs(fetch_docs)
),
)
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -241,7 +238,7 @@ def create_search_packets(
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=SearchToolStart(
is_internet_search=is_internet_search,
),
@@ -252,7 +249,7 @@ def create_search_packets(
if search_queries:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=SearchToolQueriesDelta(queries=search_queries),
),
)
@@ -261,13 +258,13 @@ def create_search_packets(
if saved_search_docs:
packets.append(
Packet(
placement=Placement(turn_index=turn_index),
turn_index=turn_index,
obj=SearchToolDocumentsDelta(
documents=SearchDoc.from_saved_search_docs(saved_search_docs)
),
),
)
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets

View File

@@ -34,13 +34,13 @@ from onyx.prompts.deep_research.research_agent import RESEARCH_REPORT_PROMPT
from onyx.prompts.deep_research.research_agent import USER_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import Placement
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -93,7 +93,7 @@ def generate_intermediate_report(
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
placement=placement,
turn_index=999, # TODO
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
@@ -127,22 +127,24 @@ def run_research_agent_call(
token_counter: Callable[[str], int],
user_identity: LLMUserIdentity | None,
) -> ResearchAgentCallResult:
research_cycle_count = 0
cycle_count = 0
llm_cycle_count = 0
current_tools = tools
gathered_documents: list[SearchDoc] | None = None
reasoning_cycles = 0
just_ran_web_search = False
turn_index = research_agent_call.placement.turn_index
tab_index = research_agent_call.placement.tab_index
turn_index = research_agent_call.turn_index
tab_index = research_agent_call.tab_index
# If this fails to parse, we can't run the loop anyway, let this one fail in that case
research_topic = research_agent_call.tool_args[RESEARCH_AGENT_TASK_KEY]
emitter.emit(
Packet(
placement=Placement(turn_index=turn_index, tab_index=tab_index),
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=0,
obj=ResearchAgentStart(research_task=research_topic),
)
)
@@ -155,9 +157,8 @@ def run_research_agent_call(
msg_history: list[ChatMessageSimple] = [initial_user_message]
citation_mapping: dict[int, str] = {}
while research_cycle_count <= RESEARCH_CYCLE_CAP:
if research_cycle_count == RESEARCH_CYCLE_CAP:
# For the last cycle, do not use any more searches, only reason or generate a report
while cycle_count <= RESEARCH_CYCLE_CAP:
if cycle_count == RESEARCH_CYCLE_CAP:
current_tools = [
tool
for tool in tools
@@ -194,7 +195,7 @@ def run_research_agent_call(
system_prompt_str = system_prompt_template.format(
available_tools=tools_description,
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=research_cycle_count,
current_cycle_count=cycle_count,
optional_internal_search_tool_description=internal_search_tip,
optional_web_search_tool_description=web_search_tip,
optional_open_urls_tool_description=open_urls_tip,
@@ -234,11 +235,7 @@ def run_research_agent_call(
+ research_agent_tools,
tool_choice=ToolChoiceOptions.REQUIRED,
llm=llm,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=llm_cycle_count + reasoning_cycles,
),
turn_index=llm_cycle_count + reasoning_cycles,
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
@@ -253,13 +250,6 @@ def run_research_agent_call(
just_ran_web_search = False
if any(
tool_call.tool_name in {SearchTool.NAME, WebSearchTool.NAME}
for tool_call in tool_calls
):
# Only the search actions increment the cycle for the max cycle count
research_cycle_count += 1
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
if special_tool_calls.generate_report_tool_call:
final_report = generate_intermediate_report(
@@ -271,10 +261,8 @@ def run_research_agent_call(
state_container=state_container,
emitter=emitter,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=llm_cycle_count + reasoning_cycles,
),
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
), # TODO
)
return ResearchAgentCallResult(
intermediate_report=final_report, search_docs=[]
@@ -328,7 +316,7 @@ def run_research_agent_call(
raise ValueError("Tool response missing tool_call reference")
tool_call = tool_response.tool_call
tab_index = tool_call.placement.tab_index
tab_index = tool_call.tab_index
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
@@ -404,10 +392,8 @@ def run_research_agent_call(
state_container=state_container,
emitter=emitter,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=llm_cycle_count + reasoning_cycles,
),
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
), # TODO
)
return ResearchAgentCallResult(intermediate_report=final_report, search_docs=[])

View File

@@ -15,7 +15,6 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
@@ -39,7 +38,8 @@ class ToolCallKickoff(BaseModel):
tool_name: str
tool_args: dict[str, Any]
placement: Placement
turn_index: int
tab_index: int
def to_msg_str(self) -> str:
return json.dumps(

View File

@@ -3,13 +3,15 @@ from __future__ import annotations
import abc
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from sqlalchemy.orm import Session
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolResponse
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from onyx.tools.models import ToolResponse
TOverride = TypeVar("TOverride")
@@ -71,7 +73,7 @@ class Tool(abc.ABC, Generic[TOverride]):
raise NotImplementedError
@abc.abstractmethod
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
"""
Emit the start packet for this tool. Each tool implementation should
emit its specific start packet type.
@@ -85,7 +87,11 @@ class Tool(abc.ABC, Generic[TOverride]):
@abc.abstractmethod
def run(
self,
placement: Placement,
# The run must know its turn because the "Tool" may actually be more of an "Agent" which can call
# other tools and must pass in this information potentially deeper down.
turn_index: int,
# Tab index for parallel tool calls (default 0 for single tool calls)
tab_index: int,
# Specific tool override arguments that are not provided by the LLM
# For example when calling the internal search tool, the original user query is passed along too (but not by the LLM)
override_kwargs: TOverride,

View File

@@ -30,8 +30,8 @@ from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.onyxbot.slack.models import SlackContext
from onyx.tools.built_in_tools import get_built_in_tool_by_id
from onyx.tools.interface import Tool
from onyx.tools.models import DynamicSchemaInfo
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)

View File

@@ -14,17 +14,16 @@ from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from onyx.tools.models import CustomToolCallSummary
from onyx.tools.models import CustomToolUserFileSnapshot
from onyx.tools.models import DynamicSchemaInfo
from onyx.tools.models import MESSAGE_ID_PLACEHOLDER
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from onyx.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
@@ -134,17 +133,19 @@ class CustomTool(Tool[None]):
"""Actual execution of the tool"""
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolStart(tool_name=self._name),
)
)
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: None = None,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -211,7 +212,8 @@ class CustomTool(Tool[None]):
# Emit CustomToolDelta packet
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=self._name,
response_type=response_type,

View File

@@ -13,14 +13,13 @@ from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.db.llm import fetch_existing_llm_providers
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
@@ -121,10 +120,11 @@ class ImageGenerationTool(Tool[None]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationToolStart(),
)
)
@@ -132,7 +132,7 @@ class ImageGenerationTool(Tool[None]):
def _generate_image(
self, prompt: str, shape: ImageShape
) -> tuple[ImageGenerationResponse, Any]:
from litellm import image_generation
from litellm import image_generation # type: ignore
if shape == ImageShape.LANDSCAPE:
if "gpt-image-1" in self.model:
@@ -209,7 +209,8 @@ class ImageGenerationTool(Tool[None]):
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: None = None,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -258,7 +259,8 @@ class ImageGenerationTool(Tool[None]):
# Emit a heartbeat packet to prevent timeout
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationToolHeartbeat(),
)
)
@@ -302,7 +304,8 @@ class ImageGenerationTool(Tool[None]):
# Emit final packet with generated images
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationFinal(images=generated_images_metadata),
)
)

View File

@@ -4,9 +4,8 @@ from sqlalchemy.orm import Session
from onyx.chat.emitter import Emitter
from onyx.db.kg_config import get_kg_config_settings
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.interface import Tool
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -69,12 +68,13 @@ class KnowledgeGraphTool(Tool[None]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
raise NotImplementedError("KnowledgeGraphTool.emit_start is not implemented.")
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: None = None,
**llm_kwargs: Any,
) -> ToolResponse:

View File

@@ -6,13 +6,12 @@ from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
from onyx.db.models import MCPConnectionConfig
from onyx.db.models import MCPServer
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import CustomToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.mcp.mcp_client import call_mcp_tool
from onyx.utils.logger import setup_logger
@@ -88,17 +87,19 @@ class MCPTool(Tool[None]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolStart(tool_name=self._name),
)
)
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: None = None,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -145,7 +146,8 @@ class MCPTool(Tool[None]):
# Emit CustomToolDelta packet
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=self._name,
response_type="json",
@@ -180,7 +182,8 @@ class MCPTool(Tool[None]):
# Emit CustomToolDelta packet
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=self._name,
response_type="json",
@@ -234,7 +237,8 @@ class MCPTool(Tool[None]):
# Emit CustomToolDelta packet
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=self._name,
response_type="json",

View File

@@ -9,14 +9,13 @@ from onyx.chat.emitter import Emitter
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchDocsResponse
from onyx.context.search.utils import convert_inference_sections_to_search_docs
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
from onyx.tools.tool_implementations.web_search.providers import (
get_default_content_provider,
@@ -179,25 +178,28 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
"""Emit start packet to signal tool has started."""
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlStart(),
)
)
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: OpenURLToolOverrideKwargs,
**llm_kwargs: Any,
) -> ToolResponse:
"""Execute the open URL tool to fetch content from the specified URLs.
Args:
placement: The placement info (turn_index and tab_index) for this tool call.
turn_index: The current turn index in the conversation.
tab_index: The tab index for parallel tool calls.
override_kwargs: Override arguments including starting citation number
and existing citation_mapping to reuse citations for already-cited URLs.
**llm_kwargs: Arguments provided by the LLM, including the 'urls' field.
@@ -209,7 +211,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlUrls(urls=urls),
)
)
@@ -245,7 +248,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
# Emit documents packet AFTER crawling completes
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlDocuments(documents=search_docs),
)
)

View File

@@ -14,15 +14,14 @@ from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.interface import Tool
from onyx.tools.models import LlmPythonExecutionResult
from onyx.tools.models import PythonExecutionFile
from onyx.tools.models import PythonToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
@@ -115,7 +114,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
"""Emit start packet for this tool. Code will be emitted in run() method."""
# Note: PythonToolStart requires code, but we don't have it in emit_start
# The code is available in run() method via llm_kwargs
@@ -123,7 +122,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: PythonToolOverrideKwargs,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -131,7 +131,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
Execute Python code in the Code Interpreter service.
Args:
placement: The placement info (turn_index and tab_index) for this tool call.
turn_index: The turn index for this tool execution
tab_index: The tab index for parallel tool calls
override_kwargs: Contains chat_files to stage for execution
**llm_kwargs: Contains 'code' parameter from LLM
@@ -144,7 +145,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
# Emit start event with the code
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=PythonToolStart(code=code),
)
)
@@ -251,7 +253,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
# Emit delta with stdout/stderr and generated files
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=PythonToolDelta(
stdout=truncated_stdout,
stderr=truncated_stderr,
@@ -286,7 +289,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,

View File

@@ -64,14 +64,13 @@ from onyx.secondary_llm_flows.document_filter import select_chunks_for_relevance
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
from onyx.secondary_llm_flows.query_expansion import keyword_query_expansion
from onyx.secondary_llm_flows.query_expansion import semantic_query_rephrase
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.interface import Tool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.constants import (
KEYWORD_QUERY_HYBRID_ALPHA,
)
@@ -303,19 +302,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
@classmethod
def is_available(cls, db_session: Session) -> bool:
"""Check if search tool is available.
The search tool is available if ANY of the following exist:
- Regular connectors (team knowledge)
- Federated connectors (e.g., Slack)
- User files (User Knowledge mode)
"""
from onyx.db.connector import check_user_files_exist
return (
check_connectors_exist(db_session)
or check_federated_connectors_exist(db_session)
or check_user_files_exist(db_session)
"""Check if search tool is available by verifying connectors exist."""
return check_connectors_exist(db_session) or check_federated_connectors_exist(
db_session
)
@property
@@ -356,10 +345,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolStart(),
)
)
@@ -367,7 +357,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
@log_function_time(print_only=True)
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: SearchToolOverrideKwargs,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -487,7 +478,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Emit the queries early so the UI can display them immediately
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolQueriesDelta(
queries=all_queries,
),
@@ -595,7 +587,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolDocumentsDelta(
documents=final_ui_docs,
),

View File

@@ -9,14 +9,13 @@ from onyx.context.search.models import SearchDocsResponse
from onyx.context.search.utils import convert_inference_sections_to_search_docs
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.interface import Tool
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.utils import (
convert_inference_sections_to_llm_string,
)
@@ -111,10 +110,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
},
}
def emit_start(self, placement: Placement) -> None:
def emit_start(self, turn_index: int, tab_index: int) -> None:
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolStart(is_internet_search=True),
)
)
@@ -129,7 +129,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
def run(
self,
placement: Placement,
turn_index: int,
tab_index: int,
override_kwargs: WebSearchToolOverrideKwargs,
**llm_kwargs: Any,
) -> ToolResponse:
@@ -139,7 +140,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
# Emit queries
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolQueriesDelta(queries=queries),
)
)
@@ -203,7 +205,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
# Emit documents
self.emitter.emit(
Packet(
placement=placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolDocumentsDelta(documents=search_docs),
)
)

View File

@@ -9,13 +9,13 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDocsResponse
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -73,8 +73,9 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
tool_call_id=calls[0].tool_call_id, # Use first call's ID
tool_name=tool_name,
tool_args=merged_args,
# Use first call's placement since merged calls become a single call
placement=calls[0].placement,
turn_index=calls[0].turn_index,
# Use first call's tab_index since merged calls become a single call
tab_index=calls[0].tab_index,
)
merged_calls.append(merged_call)
else:
@@ -93,12 +94,16 @@ def _run_single_tool(
This function is designed to be run in parallel via run_functions_tuples_in_parallel.
"""
turn_index = tool_call.turn_index
tab_index = tool_call.tab_index
with function_span(tool.name) as span_fn:
span_fn.span_data.input = str(tool_call.tool_args)
try:
tool_response = tool.run(
placement=tool_call.placement,
turn_index=turn_index,
override_kwargs=override_kwargs,
tab_index=tab_index,
**tool_call.tool_args,
)
span_fn.span_data.output = tool_response.llm_facing_response
@@ -122,7 +127,8 @@ def _run_single_tool(
# Emit SectionEnd after tool completes (success or failure)
tool.emitter.emit(
Packet(
placement=tool_call.placement,
turn_index=turn_index,
tab_index=tab_index,
obj=SectionEnd(),
)
)
@@ -204,7 +210,7 @@ def run_tool_calls(
tool = tools_by_name[tool_call.tool_name]
# Emit the tool start packet before running the tool
tool.emit_start(placement=tool_call.placement)
tool.emit_start(turn_index=tool_call.turn_index, tab_index=tool_call.tab_index)
override_kwargs: (
SearchToolOverrideKwargs

View File

@@ -9,7 +9,7 @@ from onyx.db.models import LLMProvider
from onyx.llm.utils import find_model_obj
from onyx.llm.utils import get_model_map
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.tools.interface import Tool
from onyx.tools.tool import Tool
def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool:

View File

@@ -74,20 +74,20 @@ class BraintrustTracingProcessor(TracingProcessor):
current_context = braintrust.current_span()
if current_context != NOOP_SPAN:
self._spans[trace.trace_id] = current_context.start_span(
self._spans[trace.trace_id] = current_context.start_span( # type: ignore[assignment]
name=trace.name,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,
)
elif self._logger is not None:
self._spans[trace.trace_id] = self._logger.start_span(
self._spans[trace.trace_id] = self._logger.start_span( # type: ignore[assignment]
span_attributes={"type": "task", "name": trace.name},
span_id=trace.trace_id,
root_span_id=trace.trace_id,
metadata=metadata,
)
else:
self._spans[trace.trace_id] = braintrust.start_span(
self._spans[trace.trace_id] = braintrust.start_span( # type: ignore[assignment]
id=trace.trace_id,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,

View File

@@ -208,7 +208,7 @@ def _as_utc_nano(dt: datetime) -> int:
def _get_span_name(obj: Span[Any]) -> str:
if hasattr(data := obj.span_data, "name") and isinstance(name := data.name, str):
return name
return obj.span_data.type
return obj.span_data.type # type: ignore[no-any-return]
def _get_span_kind(obj: SpanData) -> str:

View File

@@ -221,7 +221,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -336,7 +336,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -352,7 +352,7 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return cast(TimeoutThread[R], task)

View File

@@ -14,7 +14,6 @@ plugins = "sqlalchemy.ext.mypy.plugin"
mypy_path = "backend"
explicit_package_bases = true
disallow_untyped_defs = true
warn_unused_ignores = true
enable_error_code = ["possibly-undefined"]
strict_equality = true
exclude = [

View File

@@ -82,10 +82,6 @@ colorama==0.4.6 ; sys_platform == 'win32'
# tqdm
comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cycler==0.12.1
# via matplotlib
debugpy==1.8.17
# via ipykernel
decorator==5.2.1
@@ -116,8 +112,6 @@ filelock==3.20.1
# via
# huggingface-hub
# virtualenv
fonttools==4.61.1
# via matplotlib
frozenlist==1.8.0
# via
# aiohttp
@@ -239,16 +233,12 @@ jupyter-core==5.9.1
# via
# ipykernel
# jupyter-client
kiwisolver==1.4.9
# via matplotlib
litellm==1.79.0
# via onyx
manygo==0.2.0
# via onyx
markupsafe==3.0.3
# via jinja2
matplotlib==3.10.8
# via onyx
matplotlib-inline==0.2.1
# via
# ipykernel
@@ -271,8 +261,6 @@ nodeenv==1.9.1
# via pre-commit
numpy==1.26.4
# via
# contourpy
# matplotlib
# pandas-stubs
# shapely
# voyageai
@@ -294,7 +282,6 @@ packaging==24.2
# hatchling
# huggingface-hub
# ipykernel
# matplotlib
# pytest
pandas-stubs==2.2.3.241009
# via onyx
@@ -308,8 +295,6 @@ pathspec==0.12.1
# hatchling
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
# via ipython
pillow==12.0.0
# via matplotlib
platformdirs==4.5.0
# via
# black
@@ -378,8 +363,6 @@ pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
pyparsing==3.2.5
# via matplotlib
pytest==8.3.5
# via
# onyx
@@ -398,7 +381,6 @@ python-dateutil==2.8.2
# botocore
# google-cloud-bigquery
# jupyter-client
# matplotlib
python-dotenv==1.1.1
# via
# litellm

View File

@@ -31,7 +31,7 @@ def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
)
"""
try:
bucket_type, bucket_name, *rest = request.param
bucket_type, bucket_name, *rest = request.param # type: ignore[misc]
except Exception as e:
raise AssertionError(
"blob_connector requires (BlobType, bucket_name, [init_kwargs])"

View File

@@ -36,7 +36,7 @@ def _prepare_connector(exclude: bool) -> GoogleDriveConnector:
exclude_domain_link_only=exclude,
)
connector._creds = object() # type: ignore[assignment]
connector._primary_admin_email = "admin@example.com"
connector._primary_admin_email = "admin@example.com" # type: ignore[attr-defined]
return connector

View File

@@ -1,442 +0,0 @@
"""
External dependency unit tests for document processing job priority.
Tests that first-time indexing connectors (no last_successful_index_time)
get higher priority than re-indexing jobs from connectors that have
previously completed indexing.
Uses real Redis for locking and real database objects for CC pairs and search settings.
"""
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
try_creating_docfetching_task,
)
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import EmbeddingPrecision
from onyx.db.enums import IndexModelStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import SearchSettings
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_test_connector(db_session: Session, name: str) -> Connector:
    """Build, persist, and return a FILE-source test Connector.

    The row is committed and refreshed so its generated primary key is
    populated before being handed back to the caller.
    """
    connector_kwargs = dict(
        name=name,
        source=DocumentSource.FILE,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={},
        refresh_freq=3600,  # hourly refresh, matching the other test fixtures
    )
    new_connector = Connector(**connector_kwargs)
    db_session.add(new_connector)
    db_session.commit()
    db_session.refresh(new_connector)
    return new_connector
def _create_test_credential(db_session: Session) -> Credential:
    """Persist and return an admin-public FILE-source test Credential.

    A random 8-hex-character suffix keeps names unique across test runs.
    """
    unique_name = f"test_credential_{uuid4().hex[:8]}"
    new_credential = Credential(
        name=unique_name,
        source=DocumentSource.FILE,
        credential_json={},
        admin_public=True,
    )
    # Commit + refresh so the generated ID is available to callers.
    db_session.add(new_credential)
    db_session.commit()
    db_session.refresh(new_credential)
    return new_credential
def _create_test_cc_pair(
    db_session: Session,
    connector: Connector,
    credential: Credential,
    status: ConnectorCredentialPairStatus,
    name: str,
    last_successful_index_time: datetime | None = None,
) -> ConnectorCredentialPair:
    """Persist and return a public ConnectorCredentialPair linking the given
    connector and credential.

    Passing ``last_successful_index_time=None`` models a pair that has never
    completed an index run (the "first-time indexing" case in these tests).
    """
    pair = ConnectorCredentialPair(
        name=name,
        connector_id=connector.id,
        credential_id=credential.id,
        status=status,
        access_type=AccessType.PUBLIC,
        last_successful_index_time=last_successful_index_time,
    )
    # Commit + refresh so the generated primary key is visible to callers.
    db_session.add(pair)
    db_session.commit()
    db_session.refresh(pair)
    return pair
def _create_test_search_settings(
    db_session: Session, index_name: str
) -> SearchSettings:
    """Persist and return a minimal PRESENT-status SearchSettings row.

    Uses a fixed fake embedding model ("test-model", 768 dims, float
    precision); only ``index_name`` varies per test.
    """
    settings_kwargs = dict(
        model_name="test-model",
        model_dim=768,
        normalize=True,
        query_prefix="",
        passage_prefix="",
        status=IndexModelStatus.PRESENT,
        index_name=index_name,
        embedding_precision=EmbeddingPrecision.FLOAT,
    )
    new_settings = SearchSettings(**settings_kwargs)
    db_session.add(new_settings)
    db_session.commit()
    db_session.refresh(new_settings)
    return new_settings
class TestDocfetchingTaskPriorityWithRealObjects:
    """
    Tests for document fetching task priority based on last_successful_index_time.
    Uses real Redis for locking and real database objects for CC pairs
    and search settings.
    """

    @pytest.mark.parametrize(
        "has_successful_index,expected_priority",
        [
            # First-time indexing (no last_successful_index_time) should get HIGH priority
            (False, OnyxCeleryPriority.HIGH),
            # Re-indexing (has last_successful_index_time) should get MEDIUM priority
            (True, OnyxCeleryPriority.MEDIUM),
        ],
    )
    @patch(
        "onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
    )
    def test_priority_based_on_last_successful_index_time(
        self,
        mock_try_create_index_attempt: MagicMock,
        db_session: Session,
        has_successful_index: bool,
        expected_priority: OnyxCeleryPriority,
    ) -> None:
        """
        Test that first-time indexing connectors get higher priority than re-indexing.
        Priority is determined by last_successful_index_time:
        - None (never indexed): HIGH priority
        - Has timestamp (previously indexed): MEDIUM priority
        Uses real Redis for locking and real database objects.
        """
        # Create unique names to avoid conflicts between test runs
        unique_suffix = uuid4().hex[:8]
        # Determine last_successful_index_time based on the test case
        last_successful_index_time = (
            datetime.now(timezone.utc) if has_successful_index else None
        )
        # Create real database objects
        connector = _create_test_connector(
            db_session, f"test_connector_{has_successful_index}_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.ACTIVE,
            name=f"test_cc_pair_{has_successful_index}_{unique_suffix}",
            last_successful_index_time=last_successful_index_time,
        )
        search_settings = _create_test_search_settings(
            db_session, f"test_index_{unique_suffix}"
        )
        # Mock the index attempt creation to return a valid ID
        mock_try_create_index_attempt.return_value = 12345
        # Mock celery app to capture task submission
        mock_celery_app = MagicMock()
        mock_celery_app.send_task.return_value = MagicMock()
        # Use real Redis client
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        # Call the function with real objects
        result = try_creating_docfetching_task(
            celery_app=mock_celery_app,
            cc_pair=cc_pair,
            search_settings=search_settings,
            reindex=False,
            db_session=db_session,
            r=redis_client,
            tenant_id=TEST_TENANT_ID,
        )
        # Verify task was created (returns the mocked index attempt ID)
        assert result == 12345
        # Verify send_task was called with the expected priority
        mock_celery_app.send_task.assert_called_once()
        call_kwargs = mock_celery_app.send_task.call_args
        actual_priority = call_kwargs.kwargs["priority"]
        assert actual_priority == expected_priority, (
            f"Expected priority {expected_priority} for has_successful_index={has_successful_index}, "
            f"but got {actual_priority}"
        )

    @patch(
        "onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
    )
    def test_no_task_created_when_deleting(
        self,
        mock_try_create_index_attempt: MagicMock,
        db_session: Session,
    ) -> None:
        """Test that no task is created when connector is in DELETING status."""
        unique_suffix = uuid4().hex[:8]
        connector = _create_test_connector(
            db_session, f"test_connector_deleting_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        # cc_pair in DELETING status — the function under test should bail out
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.DELETING,
            name=f"test_cc_pair_deleting_{unique_suffix}",
        )
        search_settings = _create_test_search_settings(
            db_session, f"test_index_deleting_{unique_suffix}"
        )
        mock_celery_app = MagicMock()
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        result = try_creating_docfetching_task(
            celery_app=mock_celery_app,
            cc_pair=cc_pair,
            search_settings=search_settings,
            reindex=False,
            db_session=db_session,
            r=redis_client,
            tenant_id=TEST_TENANT_ID,
        )
        # Verify no task was created: no return value, no celery submission,
        # and no index attempt row was even attempted
        assert result is None
        mock_celery_app.send_task.assert_not_called()
        mock_try_create_index_attempt.assert_not_called()

    @patch(
        "onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
    )
    def test_redis_lock_prevents_concurrent_task_creation(
        self,
        mock_try_create_index_attempt: MagicMock,
        db_session: Session,
    ) -> None:
        """
        Test that the Redis lock prevents concurrent task creation attempts.
        This test uses real Redis to verify the locking mechanism works correctly.
        When the lock is already held, the function should return None without
        attempting to create a task.
        """
        unique_suffix = uuid4().hex[:8]
        connector = _create_test_connector(
            db_session, f"test_connector_lock_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.INITIAL_INDEXING,
            name=f"test_cc_pair_lock_{unique_suffix}",
        )
        search_settings = _create_test_search_settings(
            db_session, f"test_index_lock_{unique_suffix}"
        )
        mock_try_create_index_attempt.return_value = 12345
        mock_celery_app = MagicMock()
        mock_celery_app.send_task.return_value = MagicMock()
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        # Acquire the lock before calling the function
        from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX

        # NOTE(review): lock name must match the one used inside
        # try_creating_docfetching_task — confirm if that function is renamed
        lock = redis_client.lock(
            DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
            timeout=30,
        )
        try:
            acquired = lock.acquire(blocking=False)
            assert acquired, "Failed to acquire lock for test"
            # Now try to create a task - should fail because lock is held
            result = try_creating_docfetching_task(
                celery_app=mock_celery_app,
                cc_pair=cc_pair,
                search_settings=search_settings,
                reindex=False,
                db_session=db_session,
                r=redis_client,
                tenant_id=TEST_TENANT_ID,
            )
            # Should return None because lock couldn't be acquired
            assert result is None
            mock_celery_app.send_task.assert_not_called()
        finally:
            # Always release the lock
            if lock.owned():
                lock.release()

    @patch(
        "onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
    )
    def test_lock_released_after_successful_task_creation(
        self,
        mock_try_create_index_attempt: MagicMock,
        db_session: Session,
    ) -> None:
        """
        Test that the Redis lock is released after successful task creation.
        This verifies that subsequent calls can acquire the lock and create tasks.
        """
        unique_suffix = uuid4().hex[:8]
        connector = _create_test_connector(
            db_session, f"test_connector_release_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.INITIAL_INDEXING,
            name=f"test_cc_pair_release_{unique_suffix}",
        )
        search_settings = _create_test_search_settings(
            db_session, f"test_index_release_{unique_suffix}"
        )
        mock_try_create_index_attempt.return_value = 12345
        mock_celery_app = MagicMock()
        mock_celery_app.send_task.return_value = MagicMock()
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        # First call should succeed
        result1 = try_creating_docfetching_task(
            celery_app=mock_celery_app,
            cc_pair=cc_pair,
            search_settings=search_settings,
            reindex=False,
            db_session=db_session,
            r=redis_client,
            tenant_id=TEST_TENANT_ID,
        )
        assert result1 == 12345
        # Reset mocks for second call
        mock_celery_app.reset_mock()
        mock_try_create_index_attempt.reset_mock()
        mock_try_create_index_attempt.return_value = 67890
        # Second call should also succeed (lock was released)
        result2 = try_creating_docfetching_task(
            celery_app=mock_celery_app,
            cc_pair=cc_pair,
            search_settings=search_settings,
            reindex=False,
            db_session=db_session,
            r=redis_client,
            tenant_id=TEST_TENANT_ID,
        )
        assert result2 == 67890
        # Both calls should have submitted tasks
        # (mock was reset between calls, so called_once refers to the 2nd call)
        mock_celery_app.send_task.assert_called_once()

    @patch(
        "onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
    )
    def test_user_file_connector_uses_correct_queue(
        self,
        mock_try_create_index_attempt: MagicMock,
        db_session: Session,
    ) -> None:
        """
        Test that user file connectors use the USER_FILES_INDEXING queue.
        """
        from onyx.configs.constants import OnyxCeleryQueues

        unique_suffix = uuid4().hex[:8]
        connector = _create_test_connector(
            db_session, f"test_connector_userfile_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.INITIAL_INDEXING,
            name=f"test_cc_pair_userfile_{unique_suffix}",
        )
        # Mark as user file (flag set post-creation since the helper
        # does not expose it as a parameter)
        cc_pair.is_user_file = True
        db_session.commit()
        db_session.refresh(cc_pair)
        search_settings = _create_test_search_settings(
            db_session, f"test_index_userfile_{unique_suffix}"
        )
        mock_try_create_index_attempt.return_value = 12345
        mock_celery_app = MagicMock()
        mock_celery_app.send_task.return_value = MagicMock()
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        result = try_creating_docfetching_task(
            celery_app=mock_celery_app,
            cc_pair=cc_pair,
            search_settings=search_settings,
            reindex=False,
            db_session=db_session,
            r=redis_client,
            tenant_id=TEST_TENANT_ID,
        )
        assert result == 12345
        mock_celery_app.send_task.assert_called_once()
        call_kwargs = mock_celery_app.send_task.call_args
        assert call_kwargs.kwargs["queue"] == OnyxCeleryQueues.USER_FILES_INDEXING
        # User files with no last_successful_index_time should get HIGH priority
        assert call_kwargs.kwargs["priority"] == OnyxCeleryPriority.HIGH

View File

@@ -1,258 +0,0 @@
"""
External dependency unit tests for docprocessing task priority.
Tests that docprocessing tasks spawned by connector_document_extraction
get the correct priority based on last_successful_index_time.
Uses real database objects for CC pairs, search settings, and index attempts.
"""
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.background.indexing.run_docfetching import connector_document_extraction
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import EmbeddingPrecision
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_test_connector(db_session: Session, name: str) -> Connector:
    """Build, persist, and return a FILE-source test Connector.

    Commit + refresh ensures the generated primary key is populated
    before the object is returned.
    """
    connector_kwargs = dict(
        name=name,
        source=DocumentSource.FILE,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={},
        refresh_freq=3600,  # hourly refresh, matching the other test fixtures
    )
    new_connector = Connector(**connector_kwargs)
    db_session.add(new_connector)
    db_session.commit()
    db_session.refresh(new_connector)
    return new_connector
def _create_test_credential(db_session: Session) -> Credential:
    """Persist and return an admin-public FILE-source test Credential.

    A random 8-hex-character suffix keeps names unique across test runs.
    """
    unique_name = f"test_credential_{uuid4().hex[:8]}"
    new_credential = Credential(
        name=unique_name,
        source=DocumentSource.FILE,
        credential_json={},
        admin_public=True,
    )
    # Commit + refresh so the generated ID is available to callers.
    db_session.add(new_credential)
    db_session.commit()
    db_session.refresh(new_credential)
    return new_credential
def _create_test_cc_pair(
    db_session: Session,
    connector: Connector,
    credential: Credential,
    status: ConnectorCredentialPairStatus,
    name: str,
    last_successful_index_time: datetime | None = None,
) -> ConnectorCredentialPair:
    """Persist and return a public ConnectorCredentialPair linking the given
    connector and credential.

    Passing ``last_successful_index_time=None`` models a pair that has never
    completed an index run (the "first-time indexing" case in these tests).
    """
    pair = ConnectorCredentialPair(
        name=name,
        connector_id=connector.id,
        credential_id=credential.id,
        status=status,
        access_type=AccessType.PUBLIC,
        last_successful_index_time=last_successful_index_time,
    )
    # Commit + refresh so the generated primary key is visible to callers.
    db_session.add(pair)
    db_session.commit()
    db_session.refresh(pair)
    return pair
def _create_test_search_settings(
    db_session: Session, index_name: str
) -> SearchSettings:
    """Persist and return a minimal PRESENT-status SearchSettings row.

    Uses a fixed fake embedding model ("test-model", 768 dims, float
    precision); only ``index_name`` varies per test.
    """
    settings_kwargs = dict(
        model_name="test-model",
        model_dim=768,
        normalize=True,
        query_prefix="",
        passage_prefix="",
        status=IndexModelStatus.PRESENT,
        index_name=index_name,
        embedding_precision=EmbeddingPrecision.FLOAT,
    )
    new_settings = SearchSettings(**settings_kwargs)
    db_session.add(new_settings)
    db_session.commit()
    db_session.refresh(new_settings)
    return new_settings
def _create_test_index_attempt(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    search_settings: SearchSettings,
    from_beginning: bool = False,
) -> IndexAttempt:
    """Persist and return an IN_PROGRESS IndexAttempt for the given cc_pair
    and search settings.

    The celery task id is a random placeholder; the attempt is never
    executed by a real worker in these tests.
    """
    fake_task_id = f"test_celery_task_{uuid4().hex[:8]}"
    attempt = IndexAttempt(
        connector_credential_pair_id=cc_pair.id,
        search_settings_id=search_settings.id,
        from_beginning=from_beginning,
        status=IndexingStatus.IN_PROGRESS,
        celery_task_id=fake_task_id,
    )
    # Commit + refresh so the generated ID is available to callers.
    db_session.add(attempt)
    db_session.commit()
    db_session.refresh(attempt)
    return attempt
class TestDocprocessingPriorityInDocumentExtraction:
    """
    Tests for docprocessing task priority within connector_document_extraction.
    Verifies that the priority passed to docprocessing tasks is determined
    by last_successful_index_time on the cc_pair.
    """

    @pytest.mark.parametrize(
        "has_successful_index,expected_priority",
        [
            # First-time indexing (no last_successful_index_time) should get HIGH priority
            (False, OnyxCeleryPriority.HIGH),
            # Re-indexing (has last_successful_index_time) should get MEDIUM priority
            (True, OnyxCeleryPriority.MEDIUM),
        ],
    )
    # NOTE: patch decorators are applied bottom-up, so the mock parameters
    # below are in reverse order of the decorators.
    @patch("onyx.background.indexing.run_docfetching.get_document_batch_storage")
    @patch("onyx.background.indexing.run_docfetching.MemoryTracer")
    @patch("onyx.background.indexing.run_docfetching._get_connector_runner")
    @patch(
        "onyx.background.indexing.run_docfetching.get_recent_completed_attempts_for_cc_pair"
    )
    @patch(
        "onyx.background.indexing.run_docfetching.get_last_successful_attempt_poll_range_end"
    )
    @patch("onyx.background.indexing.run_docfetching.save_checkpoint")
    @patch("onyx.background.indexing.run_docfetching.get_latest_valid_checkpoint")
    def test_docprocessing_priority_based_on_last_successful_index_time(
        self,
        mock_get_latest_valid_checkpoint: MagicMock,
        mock_save_checkpoint: MagicMock,
        mock_get_last_successful_attempt_poll_range_end: MagicMock,
        mock_get_recent_completed_attempts: MagicMock,
        mock_get_connector_runner: MagicMock,
        mock_memory_tracer_class: MagicMock,
        mock_get_batch_storage: MagicMock,
        db_session: Session,
        has_successful_index: bool,
        expected_priority: OnyxCeleryPriority,
    ) -> None:
        """
        Test that docprocessing tasks get the correct priority based on
        last_successful_index_time.
        Priority is determined by last_successful_index_time:
        - None (never indexed): HIGH priority
        - Has timestamp (previously indexed): MEDIUM priority
        Uses real database objects for CC pairs and search settings.
        """
        unique_suffix = uuid4().hex[:8]
        # Determine last_successful_index_time based on the test case
        last_successful_index_time = (
            datetime.now(timezone.utc) if has_successful_index else None
        )
        # Create real database objects
        connector = _create_test_connector(
            db_session, f"test_connector_docproc_{has_successful_index}_{unique_suffix}"
        )
        credential = _create_test_credential(db_session)
        cc_pair = _create_test_cc_pair(
            db_session,
            connector,
            credential,
            ConnectorCredentialPairStatus.ACTIVE,
            name=f"test_cc_pair_docproc_{has_successful_index}_{unique_suffix}",
            last_successful_index_time=last_successful_index_time,
        )
        search_settings = _create_test_search_settings(
            db_session, f"test_index_docproc_{unique_suffix}"
        )
        index_attempt = _create_test_index_attempt(
            db_session, cc_pair, search_settings, from_beginning=False
        )
        # Setup mocks
        mock_batch_storage = MagicMock()
        mock_get_batch_storage.return_value = mock_batch_storage
        mock_memory_tracer = MagicMock()
        mock_memory_tracer_class.return_value = mock_memory_tracer
        # Create checkpoint mocks - initial checkpoint has_more=True, final has_more=False
        mock_initial_checkpoint = MagicMock(has_more=True)
        mock_final_checkpoint = MagicMock(has_more=False)
        # get_latest_valid_checkpoint returns (checkpoint, resuming_from_checkpoint)
        mock_get_latest_valid_checkpoint.return_value = (mock_initial_checkpoint, False)
        # Create a mock connector runner that yields one document batch
        mock_connector = MagicMock()
        mock_connector_runner = MagicMock()
        mock_connector_runner.connector = mock_connector
        # The connector runner yields (document_batch, failure, next_checkpoint)
        # We provide one batch of documents to trigger a send_task call
        mock_doc = MagicMock()
        mock_doc.to_short_descriptor.return_value = "test_doc"
        mock_doc.sections = []
        mock_connector_runner.run.return_value = iter(
            [([mock_doc], None, mock_final_checkpoint)]
        )
        mock_get_connector_runner.return_value = mock_connector_runner
        mock_get_recent_completed_attempts.return_value = iter([])
        mock_get_last_successful_attempt_poll_range_end.return_value = 0
        # Mock celery app to capture task submission
        mock_celery_app = MagicMock()
        mock_celery_app.send_task.return_value = MagicMock()
        # Call the function
        connector_document_extraction(
            app=mock_celery_app,
            index_attempt_id=index_attempt.id,
            cc_pair_id=cc_pair.id,
            search_settings_id=search_settings.id,
            tenant_id=TEST_TENANT_ID,
            callback=None,
        )
        # Verify send_task was called with the expected priority for docprocessing
        assert mock_celery_app.send_task.called, "send_task should have been called"
        call_kwargs = mock_celery_app.send_task.call_args
        actual_priority = call_kwargs.kwargs["priority"]
        assert actual_priority == expected_priority, (
            f"Expected priority {expected_priority} for has_successful_index={has_successful_index}, "
            f"but got {actual_priority}"
        )

View File

@@ -31,7 +31,6 @@ from onyx.db.models import Persona
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import CustomToolCallSummary
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import SearchToolConfig
@@ -388,9 +387,7 @@ class TestMCPPassThroughOAuth:
):
# Run the tool
response = mcp_tool.run(
placement=Placement(turn_index=0, tab_index=0),
override_kwargs=None,
input="test",
turn_index=0, tab_index=0, override_kwargs=None, input="test"
)
print(response.rich_response)
assert isinstance(response.rich_response, CustomToolCallSummary)

View File

@@ -1,4 +1,4 @@
import generated.onyx_openapi_client.onyx_openapi_client as onyx_api # type: ignore[import-untyped,unused-ignore]
import generated.onyx_openapi_client.onyx_openapi_client as onyx_api # type: ignore[import]
from tests.integration.common_utils.constants import API_SERVER_URL
api_config = onyx_api.Configuration(host=API_SERVER_URL)

View File

@@ -1,143 +0,0 @@
"""
Utilities for testing document access control lists (ACLs) and permissions.
"""
from typing import List
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.access.access import _get_access_for_documents
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import User
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
logger = setup_logger()
def get_user_acl(user: User, db_session: Session) -> set[str]:
    """Build the complete ACL entry set for a user.

    The returned set contains the user's prefixed external group ids, the
    prefixed user email, and the public-document pattern (PUBLIC_DOC_PAT),
    so intersecting it with a document ACL answers "can this user see it".

    Args:
        user: The user object.
        db_session: Database session.

    Returns:
        Set of ACL entries for the user.
    """
    external_groups = (
        fetch_external_groups_for_user(db_session, user.id) if user else []
    )
    acl_entries = {
        prefix_external_group(group.external_user_group_id)
        for group in external_groups
    }
    # Every user also matches their own email entry and the public pattern.
    acl_entries.add(prefix_user_email(user.email))
    acl_entries.add(PUBLIC_DOC_PAT)
    return acl_entries
def get_user_document_access_via_acl(
    test_user: DATestUser, document_ids: List[str], db_session: Session
) -> List[str]:
    """Compute which of the given documents a user may access.

    Intersects the user's ACL entries with each document's ACL directly.
    This is more reliable than search-based verification because it checks
    the permission logic itself, independent of search relevance or ranking.

    Args:
        test_user: The test user to check access for.
        document_ids: List of document IDs to check.
        db_session: Database session.

    Returns:
        List of document IDs that the user can access.
    """
    # Resolve the DATestUser wrapper to the real User row.
    user = fetch_user_by_id(db_session, UUID(test_user.id))
    if not user:
        logger.error(f"Could not find user with ID {test_user.id}")
        return []
    user_acl = get_user_acl(user, db_session)
    logger.info(f"User {user.email} ACL entries: {user_acl}")
    doc_access_map = _get_access_for_documents(document_ids, db_session)
    logger.info(f"Found access info for {len(doc_access_map)} documents")
    accessible: List[str] = []
    for doc_id, doc_access in doc_access_map.items():
        doc_acl = doc_access.to_acl()
        logger.info(f"Document {doc_id} ACL: {doc_acl}")
        # Any single shared ACL entry grants access.
        if not user_acl.isdisjoint(doc_acl):
            accessible.append(doc_id)
            logger.info(f"User {user.email} has access to document {doc_id}")
        else:
            logger.info(f"User {user.email} does NOT have access to document {doc_id}")
    return accessible
def get_all_connector_documents(
    cc_pair: DATestCCPair, db_session: Session
) -> List[str]:
    """Fetch every document ID indexed by a connector/credential pair.

    Args:
        cc_pair: The connector-credential pair.
        db_session: Database session.

    Returns:
        List of document IDs.
    """
    query = select(DocumentByConnectorCredentialPair.id).where(
        DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
        DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
    )
    # Each result row is a one-element tuple holding the document id.
    document_ids = [doc_id for (doc_id,) in db_session.execute(query)]
    logger.info(
        f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
    )
    return document_ids
def get_documents_by_permission_type(
    document_ids: List[str], db_session: Session
) -> List[str]:
    """Return the subset of the given documents that are publicly accessible.

    Args:
        document_ids: List of document IDs to check.
        db_session: Database session.

    Returns:
        List of document IDs that are public.
    """
    doc_access_map = _get_access_for_documents(document_ids, db_session)
    return [
        doc_id
        for doc_id, doc_access in doc_access_map.items()
        if doc_access.is_public
    ]

View File

@@ -5,7 +5,7 @@ from uuid import uuid4
import requests
import generated.onyx_openapi_client.onyx_openapi_client as api # type: ignore[import-untyped,unused-ignore]
import generated.onyx_openapi_client.onyx_openapi_client as api # type: ignore[import]
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus

View File

@@ -49,14 +49,6 @@ class StreamPacketObj(TypedDict, total=False):
documents: list[dict[str, Any]]
class PlacementData(TypedDict, total=False):
"""Structure for packet placement information."""
turn_index: int
tab_index: int
sub_turn_index: int | None
class StreamPacketData(TypedDict, total=False):
"""Structure for streaming response packets."""
@@ -64,7 +56,7 @@ class StreamPacketData(TypedDict, total=False):
error: str
stack_trace: str
obj: StreamPacketObj
placement: PlacementData
turn_index: int
class ChatSessionManager:
@@ -200,7 +192,7 @@ class ChatSessionManager:
(
data.get("ind")
if data.get("ind") is not None
else data.get("placement", {}).get("turn_index")
else data.get("turn_index")
),
)
)

View File

@@ -1,149 +0,0 @@
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
# Convenience alias for the 6-tuple yielded by the github_test_env_setup
# fixture defined in this module.
GitHubTestEnvSetupTuple = tuple[
    DATestUser,  # admin_user
    DATestUser,  # test_user_1
    DATestUser,  # test_user_2
    DATestCredential,  # github_credential
    DATestConnector,  # github_connector
    DATestCCPair,  # github_cc_pair
]
def _get_github_test_tokens() -> list[str]:
"""
Returns a list of GitHub tokens to run the GitHub connector suite against.
Minimal setup:
- Set GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN (token1)
Optional:
- Set GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC (token2 / classic)
If the classic token is provided, the GitHub suite will run twice (once per token).
"""
token_1 = os.environ.get("GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN")
# Prefer the new "classic" name, but keep backward compatibility.
token_2 = os.environ.get("GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC")
tokens: list[str] = []
if token_1:
tokens.append(token_1)
if token_2:
tokens.append(token_2)
return tokens
@pytest.fixture(scope="module", params=_get_github_test_tokens())
def github_access_token(request: pytest.FixtureRequest) -> str:
    """Yield each configured GitHub token; skip when none is configured."""
    # Re-check at setup time so a missing configuration surfaces as an
    # explicit skip with a helpful message.
    if not _get_github_test_tokens():
        pytest.skip(
            "Skipping GitHub tests due to missing env vars "
            "GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN and "
            "GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC"
        )
    return request.param
@pytest.fixture(scope="module")
def github_test_env_setup(
    github_access_token: str,
) -> Generator[GitHubTestEnvSetupTuple]:
    """
    Create a complete GitHub test environment with:
    - 3 users whose emails come from GITHUB_ADMIN_EMAIL,
      GITHUB_TEST_USER_1_EMAIL and GITHUB_TEST_USER_2_EMAIL
    - GitHub credentials using the token supplied by the
      github_access_token fixture
    - GitHub connector configured for testing
    - Connector-Credential pair linking them together
    Returns:
        Tuple containing: (admin_user, test_user_1, test_user_2, github_credential, github_connector, github_cc_pair)
    """
    # Reset all resources so earlier tests cannot leak state into this module
    reset_all()
    # Read the required user emails; skip the whole module if any is unset
    admin_email = os.environ.get("GITHUB_ADMIN_EMAIL")
    test_user_1_email = os.environ.get("GITHUB_TEST_USER_1_EMAIL")
    test_user_2_email = os.environ.get("GITHUB_TEST_USER_2_EMAIL")
    if not admin_email or not test_user_1_email or not test_user_2_email:
        pytest.skip(
            "Skipping GitHub test environment setup due to missing environment variables"
        )
    # Create users
    admin_user: DATestUser = UserManager.create(email=admin_email)
    test_user_1: DATestUser = UserManager.create(email=test_user_1_email)
    test_user_2: DATestUser = UserManager.create(email=test_user_2_email)
    # Create LLM provider - required for document search to work
    LLMProviderManager.create(user_performing_action=admin_user)
    # Create GitHub credentials from the parametrized access token
    github_credentials = {
        "github_access_token": github_access_token,
    }
    github_credential: DATestCredential = CredentialManager.create(
        source=DocumentSource.GITHUB,
        credential_json=github_credentials,
        user_performing_action=admin_user,
    )
    # Create GitHub connector pointed at the fixed permission-sync test repo
    github_connector: DATestConnector = ConnectorManager.create(
        name="GitHub Test Connector",
        input_type=InputType.POLL,
        source=DocumentSource.GITHUB,
        connector_specific_config={
            "repo_owner": "permission-sync-test",
            "include_prs": True,
            "repositories": "perm-sync-test-minimal",
            "include_issues": True,
        },
        access_type=AccessType.SYNC,
        user_performing_action=admin_user,
    )
    # Create CC pair linking connector and credential
    github_cc_pair: DATestCCPair = CCPairManager.create(
        credential_id=github_credential.id,
        connector_id=github_connector.id,
        name="GitHub Test CC Pair",
        access_type=AccessType.SYNC,
        user_performing_action=admin_user,
    )
    # Wait for initial indexing to complete
    # GitHub API operations can be slow due to rate limiting and network latency
    # Use a longer timeout for initial indexing to avoid flaky test failures
    before = datetime.now(tz=timezone.utc)
    CCPairManager.wait_for_indexing_completion(
        cc_pair=github_cc_pair,
        after=before,
        user_performing_action=admin_user,
        timeout=900,
    )
    yield admin_user, test_user_1, test_user_2, github_credential, github_connector, github_cc_pair

View File

@@ -1,341 +0,0 @@
import os
from datetime import datetime
from datetime import timezone
import pytest
from github import Github
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.document_acl import (
get_all_connector_documents,
)
from tests.integration.common_utils.document_acl import (
get_user_document_access_via_acl,
)
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.connector_job_tests.github.conftest import (
GitHubTestEnvSetupTuple,
)
from tests.integration.connector_job_tests.github.utils import GitHubManager
logger = setup_logger()
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission tests are enterprise only",
)
def test_github_private_repo_permission_sync(
    github_test_env_setup: GitHubTestEnvSetupTuple,
) -> None:
    """
    Verify that after a permission sync of a *private* repository, only the
    user granted access through test-team can see the documents.

    NOTE(review): assumes test_user_1 is a member of test-team and
    test_user_2 is not -- this depends on external GitHub org state; confirm
    against the permission-sync-test organization setup.
    """
    (
        admin_user,
        test_user_1,
        test_user_2,
        github_credential,
        github_connector,
        github_cc_pair,
    ) = github_test_env_setup
    # Create GitHub client from credential
    github_access_token = github_credential.credential_json["github_access_token"]
    github_client = Github(github_access_token)
    github_manager = GitHubManager(github_client)
    # Get repository configuration from connector
    repo_owner = github_connector.connector_specific_config["repo_owner"]
    repo_name = github_connector.connector_specific_config["repositories"]
    # Make the repo private so access is governed by team membership
    success = github_manager.change_repository_visibility(
        repo_owner=repo_owner, repo_name=repo_name, visibility="private"
    )
    if not success:
        pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to private")
    # Add test-team to repository at the start
    logger.info(f"Adding test-team to repository {repo_owner}/{repo_name}")
    team_added = github_manager.add_team_to_repository(
        repo_owner=repo_owner,
        repo_name=repo_name,
        team_slug="test-team",
        permission="pull",
    )
    if not team_added:
        logger.warning(
            f"Failed to add test-team to repository {repo_owner}/{repo_name}"
        )
    # From here on, always remove test-team again (see finally below) so
    # later tests in this module start from a clean repository state.
    try:
        after = datetime.now(timezone.utc)
        CCPairManager.sync(
            cc_pair=github_cc_pair,
            user_performing_action=admin_user,
        )
        # Use a longer timeout for GitHub permission sync operations
        # GitHub API operations can be slow, especially with rate limiting
        # This accounts for document sync, group sync, and vespa sync operations
        CCPairManager.wait_for_sync(
            cc_pair=github_cc_pair,
            user_performing_action=admin_user,
            after=after,
            should_wait_for_group_sync=True,
            timeout=900,
        )
        # ACL-based verification
        with get_session_with_current_tenant() as db_session:
            # Get all documents for this connector
            all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
            # Test access for both users using ACL verification
            accessible_docs_user1 = get_user_document_access_via_acl(
                test_user=test_user_1,
                document_ids=all_document_ids,
                db_session=db_session,
            )
            accessible_docs_user2 = get_user_document_access_via_acl(
                test_user=test_user_2,
                document_ids=all_document_ids,
                db_session=db_session,
            )
            logger.info(
                f"test_user_1 has access to {len(accessible_docs_user1)} documents"
            )
            logger.info(
                f"test_user_2 has access to {len(accessible_docs_user2)} documents"
            )
            # test_user_1 (part of test-team) should have access
            # test_user_2 (not part of test-team) should NOT have access
            assert len(accessible_docs_user1) > 0, (
                f"test_user_1 should have access to private repository documents. "
                f"Found {len(accessible_docs_user1)} accessible docs out of "
                f"{len(all_document_ids)} total"
            )
            assert len(accessible_docs_user2) == 0, (
                f"test_user_2 should NOT have access to private repository documents. "
                f"Found {len(accessible_docs_user2)} accessible docs out of "
                f"{len(all_document_ids)} total"
            )
    finally:
        # Remove test-team from repository at the end
        logger.info(f"Removing test-team from repository {repo_owner}/{repo_name}")
        team_removed = github_manager.remove_team_from_repository(
            repo_owner=repo_owner, repo_name=repo_name, team_slug="test-team"
        )
        if not team_removed:
            logger.warning(
                f"Failed to remove test-team from repository {repo_owner}/{repo_name}"
            )
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission tests are enterprise only",
)
def test_github_public_repo_permission_sync(
    github_test_env_setup: GitHubTestEnvSetupTuple,
) -> None:
    """
    Test that when a repository is changed to public, both users can access the documents.
    """
    (
        admin_user,
        test_user_1,
        test_user_2,
        github_credential,
        github_connector,
        github_cc_pair,
    ) = github_test_env_setup
    # Build a GitHub API manager from the connector's credential
    token = github_credential.credential_json["github_access_token"]
    github_manager = GitHubManager(Github(token))
    # Repository under test, as configured on the connector
    repo_owner = github_connector.connector_specific_config["repo_owner"]
    repo_name = github_connector.connector_specific_config["repositories"]
    # Flip the repository to public visibility
    logger.info(f"Changing repository {repo_owner}/{repo_name} to public")
    visibility_changed = github_manager.change_repository_visibility(
        repo_owner=repo_owner, repo_name=repo_name, visibility="public"
    )
    if not visibility_changed:
        pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to public")
    # Sanity-check that GitHub now reports the new visibility
    current_visibility = github_manager.get_repository_visibility(
        repo_owner=repo_owner, repo_name=repo_name
    )
    logger.info(f"Repository {repo_owner}/{repo_name} visibility: {current_visibility}")
    assert (
        current_visibility == "public"
    ), f"Repository should be public, but is {current_visibility}"
    # Trigger a permission sync and wait for it (incl. group sync) to finish
    sync_start = datetime.now(timezone.utc)
    CCPairManager.sync(
        cc_pair=github_cc_pair,
        user_performing_action=admin_user,
    )
    CCPairManager.wait_for_sync(
        cc_pair=github_cc_pair,
        user_performing_action=admin_user,
        after=sync_start,
        should_wait_for_group_sync=True,
        timeout=900,
    )
    # Verify access directly via ACLs rather than via search
    with get_session_with_current_tenant() as db_session:
        all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
        accessible_docs_user1 = get_user_document_access_via_acl(
            test_user=test_user_1,
            document_ids=all_document_ids,
            db_session=db_session,
        )
        accessible_docs_user2 = get_user_document_access_via_acl(
            test_user=test_user_2,
            document_ids=all_document_ids,
            db_session=db_session,
        )
        logger.info(f"test_user_1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"test_user_2 has access to {len(accessible_docs_user2)} documents")
        # A public repository must be visible to both users, and to the
        # same extent.
        assert len(accessible_docs_user1) > 0, (
            f"test_user_1 should have access to public repository documents. "
            f"Found {len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) > 0, (
            f"test_user_2 should have access to public repository documents. "
            f"Found {len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user1) == len(accessible_docs_user2), (
            f"Both users should see the same documents from public repository. "
            f"User1: {len(accessible_docs_user1)}, User2: {len(accessible_docs_user2)}"
        )
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission tests are enterprise only",
)
def test_github_internal_repo_permission_sync(
    github_test_env_setup: GitHubTestEnvSetupTuple,
) -> None:
    """
    Test that when a repository is changed to internal, test_user_1 has access but test_user_2 doesn't.
    Internal repositories are accessible only to organization members.
    """
    (
        admin_user,
        test_user_1,
        test_user_2,
        github_credential,
        github_connector,
        github_cc_pair,
    ) = github_test_env_setup
    # Build a GitHub API manager from the connector's credential
    token = github_credential.credential_json["github_access_token"]
    github_manager = GitHubManager(Github(token))
    # Repository under test, as configured on the connector
    repo_owner = github_connector.connector_specific_config["repo_owner"]
    repo_name = github_connector.connector_specific_config["repositories"]
    # Flip the repository to internal (org-members-only) visibility
    logger.info(f"Changing repository {repo_owner}/{repo_name} to internal")
    visibility_changed = github_manager.change_repository_visibility(
        repo_owner=repo_owner, repo_name=repo_name, visibility="internal"
    )
    if not visibility_changed:
        pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to internal")
    # Sanity-check that GitHub now reports the new visibility
    current_visibility = github_manager.get_repository_visibility(
        repo_owner=repo_owner, repo_name=repo_name
    )
    logger.info(f"Repository {repo_owner}/{repo_name} visibility: {current_visibility}")
    assert (
        current_visibility == "internal"
    ), f"Repository should be internal, but is {current_visibility}"
    # Trigger a permission sync and wait for it (incl. group sync) to finish
    sync_start = datetime.now(timezone.utc)
    CCPairManager.sync(
        cc_pair=github_cc_pair,
        user_performing_action=admin_user,
    )
    CCPairManager.wait_for_sync(
        cc_pair=github_cc_pair,
        user_performing_action=admin_user,
        after=sync_start,
        should_wait_for_group_sync=True,
        timeout=900,
    )
    # Verify access directly via ACLs rather than via search
    with get_session_with_current_tenant() as db_session:
        all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
        accessible_docs_user1 = get_user_document_access_via_acl(
            test_user=test_user_1,
            document_ids=all_document_ids,
            db_session=db_session,
        )
        accessible_docs_user2 = get_user_document_access_via_acl(
            test_user=test_user_2,
            document_ids=all_document_ids,
            db_session=db_session,
        )
        logger.info(f"test_user_1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"test_user_2 has access to {len(accessible_docs_user2)} documents")
        # Only organization members may read an internal repo:
        # test_user_1 is expected to be a member, test_user_2 is not.
        assert len(accessible_docs_user1) > 0, (
            f"test_user_1 should have access to internal repository documents (organization member). "
            f"Found {len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) == 0, (
            f"test_user_2 should NOT have access to internal repository documents (not organization member). "
            f"Found {len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )

Some files were not shown because too many files have changed in this diff Show More