mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-07 16:02:45 +00:00
Compare commits
1 Commits
ods/v0.5.2
...
jtahara/do
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6dc885422 |
10
.github/workflows/pr-integration-tests.yml
vendored
10
.github/workflows/pr-integration-tests.yml
vendored
@@ -33,11 +33,6 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
|
||||
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
|
||||
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
|
||||
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
@@ -404,11 +399,6 @@ jobs:
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
|
||||
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
|
||||
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
|
||||
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
|
||||
@@ -8,66 +8,30 @@ repos:
|
||||
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
- id: uv-sync
|
||||
args: ["--active", "--locked", "--all-extras"]
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
]
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export dev.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
]
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export ee.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
]
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export model_server.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
]
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
@@ -76,73 +40,68 @@ repos:
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
files: ^.github/
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# this is a fork which keeps compatibility with black
|
||||
- repo: https://github.com/wimglenn/reorder-python-imports-black
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ["--py311-plus", "--application-directories=backend/"]
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args:
|
||||
[
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
"--in-place",
|
||||
"--recursive",
|
||||
]
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
@@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -42,13 +42,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
@@ -63,13 +63,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
|
||||
@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid)
|
||||
file_metadata = cast(Any, file_record.file_metadata)
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
else:
|
||||
# Convert other types to dict if possible, otherwise None
|
||||
try:
|
||||
file_metadata = dict(file_record.file_metadata)
|
||||
file_metadata = dict(file_record.file_metadata) # type: ignore
|
||||
except (TypeError, ValueError):
|
||||
file_metadata = None
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import (
|
||||
from onyx.db.enums import ( # type: ignore[import-untyped]
|
||||
MCPTransport,
|
||||
MCPAuthenticationType,
|
||||
MCPAuthenticationPerformer,
|
||||
|
||||
@@ -82,9 +82,9 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -118,6 +118,6 @@ def fetch_document_sets(
|
||||
.all()
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs))
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
||||
|
||||
return document_set_with_cc_pairs
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
@@ -36,8 +36,8 @@ from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore[import-untyped]
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -42,7 +42,7 @@ def get_embedding_model(
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
@@ -91,7 +91,7 @@ def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
@@ -195,7 +195,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
# Run CPU-bound reranking in a thread pool
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
@@ -34,7 +34,7 @@ class HybridClassifier(nn.Module):
|
||||
query_ids: torch.Tensor,
|
||||
query_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
# Intent classification on the CLS token
|
||||
@@ -102,7 +102,7 @@ class ConnectorClassifier(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.distilbert(
|
||||
hidden_states = self.distilbert( # type: ignore
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_access_for_document(
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
@@ -93,7 +93,9 @@ def get_access_for_documents(
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
@@ -111,7 +113,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_acl_for_user"
|
||||
)
|
||||
return versioned_acl_for_user_fn(user, db_session)
|
||||
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||
|
||||
|
||||
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
|
||||
|
||||
@@ -338,7 +338,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
user = await super().create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user_created = True
|
||||
except IntegrityError as error:
|
||||
# Race condition: another request created the same user after the
|
||||
@@ -602,7 +604,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
remove_user_from_invited_users(user.email)
|
||||
@@ -1173,7 +1178,7 @@ async def _sync_jwt_oidc_expiry(
|
||||
return
|
||||
|
||||
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
|
||||
user.oidc_expiry = oidc_expiry
|
||||
user.oidc_expiry = oidc_expiry # type: ignore
|
||||
return
|
||||
|
||||
if user.oidc_expiry is not None:
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
has_successful_attempt = cc_pair.last_successful_index_time is not None
|
||||
priority = (
|
||||
OnyxCeleryPriority.MEDIUM
|
||||
if has_successful_attempt
|
||||
else OnyxCeleryPriority.HIGH
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=priority,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
@@ -25,14 +25,14 @@ from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state
|
||||
from onyx.background.celery.tasks.docprocessing.utils import should_index
|
||||
from onyx.background.celery.tasks.docprocessing.utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.models import DocProcessingContext
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import (
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -17,6 +24,8 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -289,3 +298,112 @@ def should_index(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
@@ -368,19 +368,11 @@ def connector_document_extraction(
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
from_beginning = index_attempt.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
docprocessing_priority = (
|
||||
OnyxCeleryPriority.MEDIUM
|
||||
if has_successful_attempt
|
||||
else OnyxCeleryPriority.HIGH
|
||||
)
|
||||
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp()
|
||||
@@ -503,7 +495,6 @@ def connector_document_extraction(
|
||||
tenant_id,
|
||||
app,
|
||||
most_recent_attempt,
|
||||
docprocessing_priority,
|
||||
)
|
||||
last_batch_num = reissued_batch_count + completed_batches
|
||||
index_attempt.completed_batches = completed_batches
|
||||
@@ -616,7 +607,7 @@ def connector_document_extraction(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=docprocessing_priority,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
@@ -767,7 +758,6 @@ def reissue_old_batches(
|
||||
tenant_id: str,
|
||||
app: Celery,
|
||||
most_recent_attempt: IndexAttempt | None,
|
||||
priority: OnyxCeleryPriority,
|
||||
) -> tuple[int, int]:
|
||||
# When loading from a checkpoint, we need to start new docprocessing tasks
|
||||
# tied to the new index attempt for any batches left over in the file store
|
||||
@@ -795,7 +785,7 @@ def reissue_old_batches(
|
||||
"batch_num": path_info.batch_num, # use same batch num as previously
|
||||
},
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=priority,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
|
||||
# resume from the batch num of the last attempt. This should be one more
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
@@ -87,7 +86,7 @@ class ChatStateContainer:
|
||||
return self.is_clarification
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
def run_chat_llm_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
@@ -111,7 +110,7 @@ def run_chat_loop_with_state_containers(
|
||||
**kwargs: Additional keyword arguments for func
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
packets = run_chat_llm_with_state_containers(
|
||||
my_func,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -132,7 +131,7 @@ def run_chat_loop_with_state_containers(
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
turn_index=0,
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -27,12 +27,11 @@ from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
|
||||
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
@@ -444,7 +443,7 @@ def run_llm_loop(
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
citation_processor=citation_processor,
|
||||
state_container=state_container,
|
||||
# The rich docs representation is passed in so that when yielding the answer, it can also
|
||||
@@ -496,7 +495,7 @@ def run_llm_loop(
|
||||
raise ValueError("Tool response missing tool_call reference")
|
||||
|
||||
tool_call = tool_response.tool_call
|
||||
tab_index = tool_call.placement.tab_index
|
||||
tab_index = tool_call.tab_index
|
||||
|
||||
# Track if search tool was called (for skipping query expansion on subsequent calls)
|
||||
if tool_call.tool_name == SearchTool.NAME:
|
||||
@@ -626,7 +625,7 @@ def run_llm_loop(
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -38,7 +38,6 @@ from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDone
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
@@ -199,7 +198,6 @@ def _update_tool_call_with_delta(
|
||||
def _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]],
|
||||
turn_index: int,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract ToolCallKickoff objects from the tool call map.
|
||||
|
||||
@@ -224,11 +222,8 @@ def _extract_tool_call_kickoffs(
|
||||
tool_call_id=tool_call_data["id"],
|
||||
tool_name=tool_call_data["name"],
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
@@ -380,21 +375,12 @@ def translate_history_to_llm_format(
|
||||
return messages
|
||||
|
||||
|
||||
def _increment_turns(
|
||||
turn_index: int, sub_turn_index: int | None
|
||||
) -> tuple[int, int | None]:
|
||||
if sub_turn_index is None:
|
||||
return turn_index + 1, None
|
||||
else:
|
||||
return turn_index, sub_turn_index + 1
|
||||
|
||||
|
||||
def run_llm_step_pkt_generator(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
llm: LLM,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
state_container: ChatStateContainer,
|
||||
citation_processor: DynamicCitationProcessor | None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
@@ -403,58 +389,9 @@ def run_llm_step_pkt_generator(
|
||||
custom_token_processor: (
|
||||
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
|
||||
) = None,
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, bool]]:
|
||||
"""Run an LLM step and stream the response as packets.
|
||||
NOTE: DO NOT TOUCH THIS FUNCTION BEFORE ASKING YUHONG, this is very finicky and
|
||||
delicate logic that is core to the app's main functionality.
|
||||
|
||||
This generator function streams LLM responses, processing reasoning content,
|
||||
answer content, tool calls, and citations. It yields Packet objects for
|
||||
real-time streaming to clients and accumulates the final result.
|
||||
|
||||
Args:
|
||||
history: List of chat messages in the conversation history.
|
||||
tool_definitions: List of tool definitions available to the LLM.
|
||||
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
|
||||
llm: Language model interface to use for generation.
|
||||
turn_index: Current turn index in the conversation.
|
||||
state_container: Container for storing chat state (reasoning, answers).
|
||||
citation_processor: Optional processor for extracting and formatting citations
|
||||
from the response. If provided, processes tokens to identify citations.
|
||||
reasoning_effort: Optional reasoning effort configuration for models that
|
||||
support reasoning (e.g., o1 models).
|
||||
final_documents: Optional list of search documents to include in the response
|
||||
start packet.
|
||||
user_identity: Optional user identity information for the LLM.
|
||||
custom_token_processor: Optional callable that processes each token delta
|
||||
before yielding. Receives (delta, processor_state) and returns
|
||||
(modified_delta, new_processor_state). Can return None for delta to skip.
|
||||
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
|
||||
|
||||
Yields:
|
||||
Packet: Streaming packets containing:
|
||||
- ReasoningStart/ReasoningDelta/ReasoningDone for reasoning content
|
||||
- AgentResponseStart/AgentResponseDelta for answer content
|
||||
- CitationInfo for extracted citations
|
||||
- ToolCallKickoff for tool calls (extracted at the end)
|
||||
|
||||
Returns:
|
||||
tuple[LlmStepResult, bool]: A tuple containing:
|
||||
- LlmStepResult: The final result with accumulated reasoning, answer,
|
||||
and tool calls (if any).
|
||||
- bool: Whether reasoning occurred during this step. This should be used to
|
||||
increment the turn index or sub_turn index for the rest of the LLM loop.
|
||||
|
||||
Note:
|
||||
The function handles incremental state updates, saving reasoning and answer
|
||||
tokens to the state container as they are generated. Tool calls are extracted
|
||||
and yielded only after the stream completes.
|
||||
"""
|
||||
|
||||
turn_index = placement.turn_index
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
|
||||
# The second return value is for the turn index because reasoning counts on the frontend as a turn
|
||||
# TODO this is maybe ok but does not align well with the backend logic too well
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
has_reasoned = 0
|
||||
|
||||
@@ -519,19 +456,11 @@ def run_llm_step_pkt_generator(
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
@@ -539,26 +468,15 @@ def run_llm_step_pkt_generator(
|
||||
if delta.content:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
@@ -572,20 +490,12 @@ def run_llm_step_pkt_generator(
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=result,
|
||||
)
|
||||
else:
|
||||
@@ -594,28 +504,17 @@ def run_llm_step_pkt_generator(
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
@@ -629,7 +528,7 @@ def run_llm_step_pkt_generator(
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
tool_calls = _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map, turn_index, sub_turn_index
|
||||
id_to_tool_call_map, turn_index + has_reasoned
|
||||
)
|
||||
if tool_calls:
|
||||
tool_calls_list: list[ToolCall] = [
|
||||
@@ -657,16 +556,15 @@ def run_llm_step_pkt_generator(
|
||||
tool_calls=None,
|
||||
)
|
||||
span_generation.span_data.output = [assistant_msg_no_tools.model_dump()]
|
||||
|
||||
# Should have closed the reasoning block, the only pathway to hit this is if the stream
|
||||
# ended with reasoning content and no other content. This is an invalid state.
|
||||
# Close reasoning block if still open (stream ended with reasoning content)
|
||||
if reasoning_start:
|
||||
raise RuntimeError("Reasoning block is still open but the stream ended.")
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
# Reasoning is always first so this should use the post-incremented value of turn_index
|
||||
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
|
||||
# as clickable items and will be stripped out instead.
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
@@ -674,20 +572,12 @@ def run_llm_step_pkt_generator(
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
turn_index=turn_index + has_reasoned,
|
||||
obj=result,
|
||||
)
|
||||
|
||||
@@ -722,7 +612,7 @@ def run_llm_step(
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
llm: LLM,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
state_container: ChatStateContainer,
|
||||
citation_processor: DynamicCitationProcessor | None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
@@ -742,7 +632,7 @@ def run_llm_step(
|
||||
tool_definitions=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
state_container=state_container,
|
||||
citation_processor=citation_processor,
|
||||
reasoning_effort=reasoning_effort,
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_state import run_chat_llm_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
@@ -65,7 +65,7 @@ from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
@@ -553,7 +553,7 @@ def stream_chat_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
@@ -568,7 +568,7 @@ def stream_chat_message_objects(
|
||||
user_identity=user_identity,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_llm_loop,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
|
||||
@@ -583,16 +583,6 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
|
||||
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
|
||||
|
||||
# Slack federated search thread context settings
|
||||
# Batch size for fetching thread context (controls concurrent API calls per batch)
|
||||
SLACK_THREAD_CONTEXT_BATCH_SIZE = int(
|
||||
os.environ.get("SLACK_THREAD_CONTEXT_BATCH_SIZE", "5")
|
||||
)
|
||||
# Maximum messages to fetch thread context for (top N by relevance get full context)
|
||||
MAX_SLACK_THREAD_CONTEXT_MESSAGES = int(
|
||||
os.environ.get("MAX_SLACK_THREAD_CONTEXT_MESSAGES", "5")
|
||||
)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -708,15 +698,6 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
# The intent was to have this be configurable per query, but I don't think any
|
||||
# codepath was actually configuring this, so for the migrated Vespa interface
|
||||
# we'll just use the default value, but also have it be configurable by env var.
|
||||
RECENCY_BIAS_MULTIPLIER = float(os.environ.get("RECENCY_BIAS_MULTIPLIER") or 1.0)
|
||||
|
||||
# Should match the rerank-count value set in
|
||||
# backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja.
|
||||
RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
|
||||
|
||||
#####
|
||||
# Tool Configs
|
||||
|
||||
@@ -563,7 +563,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
|
||||
@@ -38,7 +38,7 @@ class AsanaAPI:
|
||||
def __init__(
|
||||
self, api_token: str, workspace_gid: str, team_gid: str | None
|
||||
) -> None:
|
||||
self._user = None
|
||||
self._user = None # type: ignore
|
||||
self.workspace_gid = workspace_gid
|
||||
self.team_gid = team_gid
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from typing import Any
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from botocore.exceptions import PartialCredentialsError
|
||||
from botocore.session import get_session
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from mypy_boto3_s3 import S3Client # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
|
||||
@@ -2,11 +2,11 @@ from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox # type: ignore[import-untyped]
|
||||
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
|
||||
from dropbox.exceptions import AuthError
|
||||
from dropbox.files import FileMetadata # type: ignore[import-untyped]
|
||||
from dropbox.files import FolderMetadata
|
||||
from dropbox import Dropbox # type: ignore
|
||||
from dropbox.exceptions import ApiError # type:ignore
|
||||
from dropbox.exceptions import AuthError # type:ignore
|
||||
from dropbox.files import FileMetadata # type:ignore
|
||||
from dropbox.files import FolderMetadata # type:ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
|
||||
@@ -14,9 +14,9 @@ from typing import cast
|
||||
from typing import Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -1006,7 +1006,7 @@ class GoogleDriveConnector(
|
||||
file.user_email,
|
||||
)
|
||||
if file.error is None:
|
||||
file.error = exc
|
||||
file.error = exc # type: ignore[assignment]
|
||||
yield file
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
@@ -4,7 +4,7 @@ from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -179,7 +179,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
)
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.discovery import build # type: ignore[import-untyped]
|
||||
from googleapiclient.discovery import Resource
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from pywikibot import family # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
|
||||
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from typing import cast
|
||||
from typing import ClassVar
|
||||
|
||||
import pywikibot.time # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot import textlib
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot import textlib # type: ignore[import-untyped]
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
@@ -4,9 +4,9 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.teams.channels.channel import Channel # type: ignore[import-untyped]
|
||||
from office365.teams.channels.channel import ConversationMember
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.teams.channels.channel import Channel # type: ignore
|
||||
from office365.teams.channels.channel import ConversationMember # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
|
||||
@@ -21,13 +21,6 @@ class OptionalSearchSetting(str, Enum):
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""
|
||||
The type of first-pass query to use for hybrid search.
|
||||
|
||||
The values of this enum are injected into the ranking profile name which
|
||||
should match the name in the schema.
|
||||
"""
|
||||
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@ from slack_sdk.errors import SlackApiError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_SLACK_THREAD_CONTEXT_MESSAGES
|
||||
from onyx.configs.app_configs import SLACK_THREAD_CONTEXT_BATCH_SIZE
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -625,55 +623,33 @@ def merge_slack_messages(
|
||||
return merged_messages, docid_to_message, all_filtered_channels
|
||||
|
||||
|
||||
class SlackRateLimitError(Exception):
|
||||
"""Raised when Slack API returns a rate limit error (429)."""
|
||||
|
||||
|
||||
class ThreadContextResult:
|
||||
"""Result wrapper for thread context fetch that captures error type."""
|
||||
|
||||
__slots__ = ("text", "is_rate_limited", "is_error")
|
||||
|
||||
def __init__(
|
||||
self, text: str, is_rate_limited: bool = False, is_error: bool = False
|
||||
):
|
||||
self.text = text
|
||||
self.is_rate_limited = is_rate_limited
|
||||
self.is_error = is_error
|
||||
|
||||
@classmethod
|
||||
def success(cls, text: str) -> "ThreadContextResult":
|
||||
return cls(text)
|
||||
|
||||
@classmethod
|
||||
def rate_limited(cls, original_text: str) -> "ThreadContextResult":
|
||||
return cls(original_text, is_rate_limited=True)
|
||||
|
||||
@classmethod
|
||||
def error(cls, original_text: str) -> "ThreadContextResult":
|
||||
return cls(original_text, is_error=True)
|
||||
|
||||
|
||||
def _fetch_thread_context(
|
||||
def get_contextualized_thread_text(
|
||||
message: SlackMessage, access_token: str, team_id: str | None = None
|
||||
) -> ThreadContextResult:
|
||||
) -> str:
|
||||
"""
|
||||
Fetch thread context for a message, returning a result object.
|
||||
Retrieves the initial thread message as well as the text following the message
|
||||
and combines them into a single string. If the slack query fails, returns the
|
||||
original message text.
|
||||
|
||||
Returns ThreadContextResult with:
|
||||
- success: enriched thread text
|
||||
- rate_limited: original text + flag indicating we should stop
|
||||
- error: original text for other failures (graceful degradation)
|
||||
The idea is that the message (the one that actually matched the search), the
|
||||
initial thread message, and the replies to the message are important in answering
|
||||
the user's query.
|
||||
|
||||
Args:
|
||||
message: The SlackMessage to get context for
|
||||
access_token: Slack OAuth access token
|
||||
team_id: Slack team ID for caching user profiles (optional but recommended)
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
# if it's not a thread, return the message text
|
||||
if thread_id is None:
|
||||
return ThreadContextResult.success(message.text)
|
||||
return message.text
|
||||
|
||||
slack_client = WebClient(token=access_token, timeout=30)
|
||||
# get the thread messages
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
@@ -682,44 +658,19 @@ def _fetch_thread_context(
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
# Check for rate limit error specifically
|
||||
if e.response and e.response.status_code == 429:
|
||||
logger.warning(
|
||||
f"Slack rate limit hit while fetching thread context for {channel_id}/{thread_id}"
|
||||
)
|
||||
return ThreadContextResult.rate_limited(message.text)
|
||||
# For other Slack errors, log and return original text
|
||||
logger.error(f"Slack API error in thread context fetch: {e}")
|
||||
return ThreadContextResult.error(message.text)
|
||||
except Exception as e:
|
||||
# Network errors, timeouts, etc - treat as recoverable error
|
||||
logger.error(f"Unexpected error in thread context fetch: {e}")
|
||||
return ThreadContextResult.error(message.text)
|
||||
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
|
||||
return message.text
|
||||
|
||||
# If empty response or single message (not a thread), return original text
|
||||
# make sure we didn't get an empty response or a single message (not a thread)
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
return message.text
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
# add the initial thread message
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# add the message (unless it's the initial message)
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
@@ -730,21 +681,28 @@ def _build_thread_text(
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
# Include a few messages BEFORE the matched message for context
|
||||
# This helps understand what the matched message is responding to
|
||||
start_idx = max(
|
||||
1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW
|
||||
) # Start after thread starter
|
||||
|
||||
# Add ellipsis if we're skipping messages between thread starter and context window
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
# Add context messages before the matched message
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add the matched message itself
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
# add the following replies to the thread text
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
@@ -752,19 +710,22 @@ def _build_thread_text(
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
# replace user ids with names in the thread text using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
|
||||
if team_id:
|
||||
# Use cached batch lookup when team_id is available
|
||||
user_profiles = batch_get_user_profiles(access_token, team_id, userids)
|
||||
for userid, name in user_profiles.items():
|
||||
thread_text = thread_text.replace(f"<@{userid}>", name)
|
||||
else:
|
||||
# Fallback to individual lookups (no caching) when team_id not available
|
||||
for userid in userids:
|
||||
try:
|
||||
response = slack_client.users_profile_get(user=userid)
|
||||
@@ -774,7 +735,7 @@ def _build_thread_text(
|
||||
except SlackApiError as e:
|
||||
if "user_not_found" in str(e):
|
||||
logger.debug(
|
||||
f"User {userid} not found (likely deleted/deactivated)"
|
||||
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Could not fetch profile for user {userid}: {e}")
|
||||
@@ -786,115 +747,6 @@ def _build_thread_text(
|
||||
return thread_text
|
||||
|
||||
|
||||
def fetch_thread_contexts_with_rate_limit_handling(
|
||||
slack_messages: list[SlackMessage],
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
batch_size: int = SLACK_THREAD_CONTEXT_BATCH_SIZE,
|
||||
max_messages: int | None = MAX_SLACK_THREAD_CONTEXT_MESSAGES,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Fetch thread contexts in controlled batches, stopping on rate limit.
|
||||
|
||||
Distinguishes between error types:
|
||||
- Rate limit (429): Stop processing further batches
|
||||
- Other errors: Continue processing (graceful degradation)
|
||||
|
||||
Args:
|
||||
slack_messages: Messages to fetch thread context for (should be sorted by relevance)
|
||||
access_token: Slack OAuth token
|
||||
team_id: Slack team ID for user profile caching
|
||||
batch_size: Number of concurrent API calls per batch
|
||||
max_messages: Maximum messages to fetch thread context for (None = no limit)
|
||||
|
||||
Returns:
|
||||
List of thread texts, one per input message.
|
||||
Messages beyond max_messages or after rate limit get their original text.
|
||||
"""
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
# Limit how many messages we fetch thread context for (if max_messages is set)
|
||||
if max_messages and max_messages < len(slack_messages):
|
||||
messages_for_context = slack_messages[:max_messages]
|
||||
messages_without_context = slack_messages[max_messages:]
|
||||
else:
|
||||
messages_for_context = slack_messages
|
||||
messages_without_context = []
|
||||
|
||||
logger.info(
|
||||
f"Fetching thread context for {len(messages_for_context)} of {len(slack_messages)} messages "
|
||||
f"(batch_size={batch_size}, max={max_messages or 'unlimited'})"
|
||||
)
|
||||
|
||||
results: list[str] = []
|
||||
rate_limited = False
|
||||
total_batches = (len(messages_for_context) + batch_size - 1) // batch_size
|
||||
rate_limit_batch = 0
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(messages_for_context), batch_size):
|
||||
current_batch = i // batch_size + 1
|
||||
|
||||
if rate_limited:
|
||||
# Skip remaining batches, use original message text
|
||||
remaining = messages_for_context[i:]
|
||||
skipped_batches = total_batches - rate_limit_batch
|
||||
logger.warning(
|
||||
f"Slack rate limit: skipping {len(remaining)} remaining messages "
|
||||
f"({skipped_batches} of {total_batches} batches). "
|
||||
f"Successfully enriched {len(results)} messages before rate limit."
|
||||
)
|
||||
results.extend([msg.text for msg in remaining])
|
||||
break
|
||||
|
||||
batch = messages_for_context[i : i + batch_size]
|
||||
|
||||
# _fetch_thread_context returns ThreadContextResult (never raises)
|
||||
# allow_failures=True is a safety net for any unexpected exceptions
|
||||
batch_results: list[ThreadContextResult | None] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
_fetch_thread_context,
|
||||
(msg, access_token, team_id),
|
||||
)
|
||||
for msg in batch
|
||||
],
|
||||
allow_failures=True,
|
||||
max_workers=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Process results - ThreadContextResult tells us exactly what happened
|
||||
for j, result in enumerate(batch_results):
|
||||
if result is None:
|
||||
# Unexpected exception (shouldn't happen) - use original text, stop
|
||||
logger.error(f"Unexpected None result for message {j} in batch")
|
||||
results.append(batch[j].text)
|
||||
rate_limited = True
|
||||
rate_limit_batch = current_batch
|
||||
elif result.is_rate_limited:
|
||||
# Rate limit hit - use original text, stop further batches
|
||||
results.append(result.text)
|
||||
rate_limited = True
|
||||
rate_limit_batch = current_batch
|
||||
else:
|
||||
# Success or recoverable error - use the text (enriched or original)
|
||||
results.append(result.text)
|
||||
|
||||
if rate_limited:
|
||||
logger.warning(
|
||||
f"Slack rate limit (429) hit at batch {current_batch}/{total_batches} "
|
||||
f"while fetching thread context. Stopping further API calls."
|
||||
)
|
||||
|
||||
# Add original text for messages we didn't fetch context for
|
||||
results.extend([msg.text for msg in messages_without_context])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def convert_slack_score(slack_score: float) -> float:
|
||||
"""
|
||||
Convert slack score to a score between 0 and 1.
|
||||
@@ -1112,12 +964,11 @@ def slack_retrieval(
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
# Fetch thread context with rate limit handling and message limiting
|
||||
# Messages are already sorted by relevance (slack_score), so top N get full context
|
||||
thread_texts = fetch_thread_contexts_with_rate_limit_handling(
|
||||
slack_messages=slack_messages,
|
||||
access_token=access_token,
|
||||
team_id=team_id,
|
||||
thread_texts: list[str] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(get_contextualized_thread_text, (slack_message, access_token, team_id))
|
||||
for slack_message in slack_messages
|
||||
]
|
||||
)
|
||||
for slack_message, thread_text in zip(slack_messages, thread_texts):
|
||||
slack_message.text = thread_text
|
||||
|
||||
@@ -90,16 +90,6 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
@@ -114,7 +104,6 @@ def _build_index_filters(
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
return final_filters
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ def query_analysis(query: str) -> tuple[bool, list[str]]:
|
||||
return analysis_model.predict(query)
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
|
||||
@@ -118,7 +118,6 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
@@ -349,7 +348,6 @@ def retrieve_chunks(
|
||||
list(query.filters.source_type) if query.filters.source_type else None,
|
||||
query.filters.document_set,
|
||||
slack_context,
|
||||
query.filters.user_file_ids,
|
||||
)
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -477,7 +475,6 @@ def search_chunks(
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
slack_context=slack_context,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
|
||||
@@ -63,7 +63,7 @@ def get_live_users_count(db_session: Session) -> int:
|
||||
This does NOT include invited users, "users" pulled in
|
||||
from external connectors, or API keys.
|
||||
"""
|
||||
count_stmt = func.count(User.id)
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
select_stmt = select(count_stmt)
|
||||
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
|
||||
user_count = db_session.scalar(select_stmt_w_filters)
|
||||
@@ -74,7 +74,7 @@ def get_live_users_count(db_session: Session) -> int:
|
||||
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_context_manager() as session:
|
||||
count_stmt = func.count(User.id)
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
stmt = select(count_stmt)
|
||||
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
|
||||
user_count = await session.scalar(stmt_w_filters)
|
||||
@@ -100,10 +100,10 @@ class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def get_user_db(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount)
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
|
||||
|
||||
|
||||
async def get_access_token_db(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
|
||||
yield SQLAlchemyAccessTokenDatabase(session, AccessToken)
|
||||
yield SQLAlchemyAccessTokenDatabase(session, AccessToken) # type: ignore
|
||||
|
||||
@@ -40,21 +40,6 @@ def check_connectors_exist(db_session: Session) -> bool:
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def check_user_files_exist(db_session: Session) -> bool:
|
||||
"""Check if any user files exist in the system.
|
||||
|
||||
This is used to determine if the search tool should be available
|
||||
when there are no regular connectors but there are user files
|
||||
(User Knowledge mode).
|
||||
"""
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
stmt = select(exists(UserFile).where(UserFile.status == UserFileStatus.COMPLETED))
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def fetch_connectors(
|
||||
db_session: Session,
|
||||
sources: list[DocumentSource] | None = None,
|
||||
|
||||
@@ -290,7 +290,7 @@ def get_document_counts_for_cc_pairs(
|
||||
)
|
||||
)
|
||||
|
||||
for connector_id, credential_id, cnt in db_session.execute(stmt).all():
|
||||
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
|
||||
aggregated_counts[(connector_id, credential_id)] = cnt
|
||||
|
||||
# Convert aggregated results back to the expected sequence of tuples
|
||||
@@ -1098,7 +1098,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
|
||||
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
|
||||
|
||||
def update_document_kg_stages(
|
||||
@@ -1121,7 +1121,7 @@ def update_document_kg_stages(
|
||||
result = db_session.execute(stmt)
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
|
||||
|
||||
def get_skipped_kg_documents(db_session: Session) -> list[str]:
|
||||
|
||||
@@ -2959,7 +2959,7 @@ class SlackChannelConfig(Base):
|
||||
"slack_bot_id",
|
||||
"is_default",
|
||||
unique=True,
|
||||
postgresql_where=(is_default is True),
|
||||
postgresql_where=(is_default is True), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ def _get_users_by_emails(
|
||||
"""given a list of lowercase emails,
|
||||
returns a list[User] of Users whose emails match and a list[str]
|
||||
the missing emails that had no User"""
|
||||
stmt = select(User).filter(func.lower(User.email).in_(lower_emails))
|
||||
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
|
||||
# Extract found emails and convert to lowercase to avoid case sensitivity issues
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# TODO: Notes for potential extensions and future improvements:
|
||||
# 1. Allow tools that aren't search specific tools
|
||||
# 2. Use user provided custom prompts
|
||||
# 3. Save the plan for replay
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
@@ -38,7 +37,6 @@ from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_R
|
||||
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
|
||||
from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY
|
||||
from onyx.prompts.prompt_utils import get_current_llm_day_time
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
|
||||
@@ -47,9 +45,9 @@ from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.fake_tools.research_agent import run_research_agent_calls
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -110,7 +108,7 @@ def generate_final_report(
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=999), # TODO
|
||||
turn_index=999, # TODO
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
reasoning_effort=ReasoningEffort.LOW,
|
||||
@@ -187,7 +185,7 @@ def run_deep_research_llm_loop(
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
turn_index=0,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
@@ -200,9 +198,7 @@ def run_deep_research_llm_loop(
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=OverallStop(type="stop"))
|
||||
)
|
||||
emitter.emit(Packet(turn_index=0, obj=OverallStop(type="stop")))
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
@@ -233,7 +229,7 @@ def run_deep_research_llm_loop(
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
turn_index=0,
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
@@ -248,14 +244,14 @@ def run_deep_research_llm_loop(
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
@@ -269,7 +265,7 @@ def run_deep_research_llm_loop(
|
||||
emitter.emit(
|
||||
Packet(
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(turn_index=orchestrator_start_turn_index - 1),
|
||||
turn_index=orchestrator_start_turn_index - 1,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
@@ -344,9 +340,7 @@ def run_deep_research_llm_loop(
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
),
|
||||
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
@@ -366,8 +360,6 @@ def run_deep_research_llm_loop(
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurence and hopefully multiple research
|
||||
# cycles have already ran
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
generate_final_report(
|
||||
history=simple_chat_history,
|
||||
@@ -398,8 +390,6 @@ def run_deep_research_llm_loop(
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
|
||||
@@ -420,6 +410,7 @@ def run_deep_research_llm_loop(
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
reasoning_cycles += 1
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
@@ -444,7 +435,6 @@ def run_deep_research_llm_loop(
|
||||
break
|
||||
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
|
||||
@@ -90,8 +90,6 @@ class MetadataUpdateRequest(BaseModel):
|
||||
the contents of the document.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
document_ids: list[str]
|
||||
# Passed in to help with potential optimizations of the implementation. The
|
||||
# keys should be redundant with document_ids.
|
||||
|
||||
@@ -273,7 +273,6 @@ schema {{ schema_name }} {
|
||||
# Boost based on aggregated boost calculation
|
||||
* aggregated_chunk_boost
|
||||
}
|
||||
# Target hits for hybrid retrieval should be at least this value.
|
||||
rerank-count: 1000
|
||||
}
|
||||
|
||||
@@ -342,7 +341,6 @@ schema {{ schema_name }} {
|
||||
# Boost based on aggregated boost calculation
|
||||
* aggregated_chunk_boost
|
||||
}
|
||||
# Target hits for hybrid retrieval should be at least this value.
|
||||
rerank-count: 1000
|
||||
}
|
||||
|
||||
|
||||
@@ -15,19 +15,19 @@ from typing import cast
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
import jinja2
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
@@ -88,6 +88,7 @@ from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Set the logging level to WARNING to ignore INFO and DEBUG logs
|
||||
@@ -957,37 +958,48 @@ class VespaIndex(DocumentIndex):
|
||||
offset: int = 0,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
) -> list[InferenceChunk]:
|
||||
tenant_id = filters.tenant_id if filters.tenant_id is not None else ""
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=TenantState(
|
||||
tenant_id=tenant_id,
|
||||
multitenant=self.multitenant,
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
target_hits = max(10 * num_to_retrieve, 1000)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self.index_name)
|
||||
+ vespa_where_clauses
|
||||
+ f"(({{targetHits: {target_hits}}}nearestNeighbor(embeddings, query_embedding)) "
|
||||
+ f"or ({{targetHits: {target_hits}}}nearestNeighbor(title_embedding, query_embedding)) "
|
||||
+ 'or ({grammar: "weakAnd"}userInput(@query)) '
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
)
|
||||
|
||||
final_query = " ".join(final_keywords) if final_keywords else query
|
||||
|
||||
if ranking_profile_type == QueryExpansionType.KEYWORD:
|
||||
ranking_profile = f"hybrid_search_keyword_base_{len(query_embedding)}"
|
||||
else:
|
||||
ranking_profile = f"hybrid_search_semantic_base_{len(query_embedding)}"
|
||||
|
||||
logger.info(f"Selected ranking profile: {ranking_profile}")
|
||||
|
||||
logger.debug(f"Query YQL: {yql}")
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"query": final_query,
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier),
|
||||
"input.query(alpha)": hybrid_alpha,
|
||||
"input.query(title_content_ratio)": (
|
||||
title_content_ratio
|
||||
if title_content_ratio is not None
|
||||
else TITLE_CONTENT_RATIO
|
||||
),
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
if not (
|
||||
ranking_profile_type == QueryExpansionType.KEYWORD
|
||||
or ranking_profile_type == QueryExpansionType.SEMANTIC
|
||||
):
|
||||
raise ValueError(
|
||||
f"Bug: Received invalid ranking profile type: {ranking_profile_type}"
|
||||
)
|
||||
query_type = (
|
||||
QueryType.KEYWORD
|
||||
if ranking_profile_type == QueryExpansionType.KEYWORD
|
||||
else QueryType.SEMANTIC
|
||||
)
|
||||
return vespa_document_index.hybrid_retrieval(
|
||||
query,
|
||||
query_embedding,
|
||||
final_keywords,
|
||||
query_type,
|
||||
filters,
|
||||
num_to_retrieve,
|
||||
offset,
|
||||
)
|
||||
"hits": num_to_retrieve,
|
||||
"offset": offset,
|
||||
"ranking.profile": ranking_profile,
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
return cleanup_chunks(query_vespa(params))
|
||||
|
||||
def admin_retrieval(
|
||||
self,
|
||||
|
||||
@@ -1,22 +1,11 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
|
||||
from onyx.configs.app_configs import RERANK_COUNT
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
@@ -26,7 +15,6 @@ from onyx.document_index.interfaces_new import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces_new import DocumentSectionRequest
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.vespa.chunk_retrieval import query_vespa
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
|
||||
@@ -35,32 +23,13 @@ from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
|
||||
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
|
||||
build_vespa_filters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import BATCH_SIZE
|
||||
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
from onyx.document_index.vespa_constants import YQL_BASE
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.tools.tool_implementations.search.constants import KEYWORD_QUERY_HYBRID_ALPHA
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
# Set the logging level to WARNING to ignore INFO and DEBUG logs from httpx. By
|
||||
# default it emits INFO-level logs for every request.
|
||||
httpx_logger = logging.getLogger("httpx")
|
||||
httpx_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class TenantState(BaseModel):
|
||||
"""
|
||||
Captures the tenant-related state for an instance of VespaDocumentIndex.
|
||||
@@ -133,211 +102,6 @@ def _enrich_basic_chunk_info(
|
||||
return enriched_doc_info
|
||||
|
||||
|
||||
def _cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
|
||||
"""Removes indexing-time content additions from chunks retrieved from Vespa.
|
||||
|
||||
During indexing, chunks are augmented with additional text to improve search
|
||||
quality:
|
||||
- Title prepended to content (for better keyword/semantic matching)
|
||||
- Metadata suffix appended to content
|
||||
- Contextual RAG: doc_summary (beginning) and chunk_context (end)
|
||||
|
||||
This function strips these additions before returning chunks to users,
|
||||
restoring the original document content. Cleaning is applied in sequence:
|
||||
1. Title removal:
|
||||
- Full match: Strips exact title from beginning
|
||||
- Partial match: If content starts with title[:BLURB_SIZE], splits on
|
||||
RETURN_SEPARATOR to remove title section
|
||||
2. Metadata suffix removal:
|
||||
- Strips metadata_suffix from end, plus trailing RETURN_SEPARATOR
|
||||
3. Contextual RAG removal:
|
||||
- Strips doc_summary from beginning (if present)
|
||||
- Strips chunk_context from end (if present)
|
||||
|
||||
Args:
|
||||
chunks: Chunks as retrieved from Vespa with indexing augmentations
|
||||
intact.
|
||||
|
||||
Returns:
|
||||
Clean InferenceChunk objects with augmentations removed, containing only
|
||||
the original document content that should be shown to users.
|
||||
"""
|
||||
|
||||
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
|
||||
if not chunk.title or not chunk.content:
|
||||
return chunk.content
|
||||
|
||||
if chunk.content.startswith(chunk.title):
|
||||
return chunk.content[len(chunk.title) :].lstrip()
|
||||
|
||||
# BLURB SIZE is by token instead of char but each token is at least 1 char
|
||||
# If this prefix matches the content, it's assumed the title was prepended
|
||||
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
|
||||
return (
|
||||
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
|
||||
if RETURN_SEPARATOR in chunk.content
|
||||
else chunk.content
|
||||
)
|
||||
return chunk.content
|
||||
|
||||
def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
|
||||
if not chunk.metadata_suffix:
|
||||
return chunk.content
|
||||
return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
|
||||
RETURN_SEPARATOR
|
||||
)
|
||||
|
||||
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
|
||||
# remove document summary
|
||||
if chunk.doc_summary and chunk.content.startswith(chunk.doc_summary):
|
||||
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
|
||||
# remove chunk context
|
||||
if chunk.chunk_context and chunk.content.endswith(chunk.chunk_context):
|
||||
chunk.content = chunk.content[
|
||||
: len(chunk.content) - len(chunk.chunk_context)
|
||||
].rstrip()
|
||||
return chunk.content
|
||||
|
||||
for chunk in chunks:
|
||||
chunk.content = _remove_title(chunk)
|
||||
chunk.content = _remove_metadata_suffix(chunk)
|
||||
chunk.content = _remove_contextual_rag(chunk)
|
||||
|
||||
return [chunk.to_inference_chunk() for chunk in chunks]
|
||||
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
exceptions=httpx.HTTPError,
|
||||
)
|
||||
def _update_single_chunk(
|
||||
doc_chunk_id: UUID,
|
||||
index_name: str,
|
||||
doc_id: str,
|
||||
http_client: httpx.Client,
|
||||
update_request: MetadataUpdateRequest,
|
||||
) -> None:
|
||||
"""Updates a single document chunk in Vespa.
|
||||
|
||||
TODO(andrei): Couldn't this be batched?
|
||||
|
||||
Args:
|
||||
doc_chunk_id: The ID of the chunk to update.
|
||||
index_name: The index the chunk belongs to.
|
||||
doc_id: The ID of the document the chunk belongs to.
|
||||
http_client: The HTTP client to use to make the request.
|
||||
update_request: Metadata update request object received in the bulk
|
||||
update method containing fields to update.
|
||||
"""
|
||||
|
||||
class _Boost(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: float
|
||||
|
||||
class _DocumentSets(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: dict[str, int]
|
||||
|
||||
class _AccessControl(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: dict[str, int]
|
||||
|
||||
class _Hidden(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: bool
|
||||
|
||||
class _UserProjects(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
# schema require changes here. These names were originally found in
|
||||
# backend/onyx/document_index/vespa_constants.py.
|
||||
boost: _Boost | None = None
|
||||
document_sets: _DocumentSets | None = None
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
fields: _VespaPutFields
|
||||
|
||||
boost_update: _Boost | None = (
|
||||
_Boost(assign=update_request.boost)
|
||||
if update_request.boost is not None
|
||||
else None
|
||||
)
|
||||
document_sets_update: _DocumentSets | None = (
|
||||
_DocumentSets(
|
||||
assign={document_set: 1 for document_set in update_request.document_sets}
|
||||
)
|
||||
if update_request.document_sets is not None
|
||||
else None
|
||||
)
|
||||
access_update: _AccessControl | None = (
|
||||
_AccessControl(
|
||||
assign={acl_entry: 1 for acl_entry in update_request.access.to_acl()}
|
||||
)
|
||||
if update_request.access is not None
|
||||
else None
|
||||
)
|
||||
hidden_update: _Hidden | None = (
|
||||
_Hidden(assign=update_request.hidden)
|
||||
if update_request.hidden is not None
|
||||
else None
|
||||
)
|
||||
user_projects_update: _UserProjects | None = (
|
||||
_UserProjects(assign=list(update_request.project_ids))
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
document_sets=document_sets_update,
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
fields=vespa_put_fields,
|
||||
)
|
||||
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
"?create=true"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=vespa_put_request.model_dump(
|
||||
exclude_none=True
|
||||
), # NOTE: Important to not produce null fields in the json.
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). "
|
||||
f"Code: {e.response.status_code}. Details: {e.response.text}"
|
||||
)
|
||||
# Re-raise so the @retry decorator will catch and retry, unless the
|
||||
# status code is < 5xx, in which case wrap the exception in something
|
||||
# other than an HTTPError to skip retries.
|
||||
if e.response.status_code >= 500:
|
||||
raise
|
||||
raise RuntimeError(
|
||||
f"Non-retryable error updating chunk {doc_chunk_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
class VespaDocumentIndex(DocumentIndex):
|
||||
"""Vespa-specific implementation of the DocumentIndex interface.
|
||||
|
||||
@@ -371,10 +135,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
get_vespa_http_client
|
||||
)
|
||||
self._multitenant = tenant_state.multitenant
|
||||
if self._multitenant:
|
||||
assert (
|
||||
self._tenant_id
|
||||
), "Bug: Must supply a tenant id if in multitenant mode."
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self, embedding_dim: int, embedding_precision: EmbeddingPrecision
|
||||
@@ -499,35 +259,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
|
||||
with self._httpx_client_context as httpx_client:
|
||||
# Each invocation of this method can contain multiple update requests.
|
||||
for update_request in update_requests:
|
||||
# Each update request can correspond to multiple documents.
|
||||
for doc_id in update_request.document_ids:
|
||||
chunk_count = update_request.doc_id_to_chunk_cnt[doc_id]
|
||||
sanitized_doc_id = replace_invalid_doc_id_characters(doc_id)
|
||||
enriched_doc_info = _enrich_basic_chunk_info(
|
||||
index_name=self._index_name,
|
||||
http_client=httpx_client,
|
||||
document_id=sanitized_doc_id,
|
||||
previous_chunk_count=chunk_count,
|
||||
new_chunk_count=0, # WARNING: This semantically makes no sense and is misusing this function.
|
||||
)
|
||||
|
||||
doc_chunk_ids = get_document_chunk_ids(
|
||||
enriched_document_info_list=[enriched_doc_info],
|
||||
tenant_id=self._tenant_id,
|
||||
large_chunks_enabled=self._large_chunks_enabled,
|
||||
)
|
||||
|
||||
for doc_chunk_id in doc_chunk_ids:
|
||||
_update_single_chunk(
|
||||
doc_chunk_id,
|
||||
self._index_name,
|
||||
doc_id,
|
||||
httpx_client,
|
||||
update_request,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def id_based_retrieval(
|
||||
self, chunk_requests: list[DocumentSectionRequest]
|
||||
@@ -544,56 +276,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the rerank-count value set in the
|
||||
# Vespa schema config. Otherwise we would be getting fewer results than
|
||||
# expected for reranking.
|
||||
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self._index_name)
|
||||
+ vespa_where_clauses
|
||||
+ f"(({{targetHits: {target_hits}}}nearestNeighbor(embeddings, query_embedding)) "
|
||||
+ f"or ({{targetHits: {target_hits}}}nearestNeighbor(title_embedding, query_embedding)) "
|
||||
+ 'or ({grammar: "weakAnd"}userInput(@query)) '
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
)
|
||||
|
||||
final_query = " ".join(final_keywords) if final_keywords else query
|
||||
|
||||
ranking_profile = (
|
||||
f"hybrid_search_{query_type.value}_base_{len(query_embedding)}"
|
||||
)
|
||||
|
||||
logger.info(f"Selected ranking profile: {ranking_profile}")
|
||||
|
||||
logger.debug(f"Query YQL: {yql}")
|
||||
|
||||
# In this interface we do not pass in hybrid alpha. Tracing the codepath
|
||||
# of the legacy Vespa interface, it so happens that KEYWORD always
|
||||
# corresponds to an alpha of 0.2 (from KEYWORD_QUERY_HYBRID_ALPHA), and
|
||||
# SEMANTIC to 0.5 (from HYBRID_ALPHA). HYBRID_ALPHA_KEYWORD was only
|
||||
# used in dead code so we do not use it here.
|
||||
hybrid_alpha = (
|
||||
KEYWORD_QUERY_HYBRID_ALPHA
|
||||
if query_type == QueryType.KEYWORD
|
||||
else HYBRID_ALPHA
|
||||
)
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"query": final_query,
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * RECENCY_BIAS_MULTIPLIER),
|
||||
"input.query(alpha)": hybrid_alpha,
|
||||
"input.query(title_content_ratio)": TITLE_CONTENT_RATIO,
|
||||
"hits": num_to_retrieve,
|
||||
"offset": offset,
|
||||
"ranking.profile": ranking_profile,
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
return _cleanup_chunks(query_vespa(params))
|
||||
raise NotImplementedError
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
|
||||
@@ -34,7 +34,7 @@ class BraintrustEvalProvider(EvalProvider):
|
||||
eval_data = [EvalCase(input=item["input"]) for item in data]
|
||||
Eval(
|
||||
name=BRAINTRUST_PROJECT,
|
||||
data=eval_data,
|
||||
data=eval_data, # type: ignore[arg-type]
|
||||
task=task,
|
||||
scores=[],
|
||||
metadata={**configuration.model_dump()},
|
||||
|
||||
@@ -38,18 +38,7 @@ def get_federated_retrieval_functions(
|
||||
source_types: list[DocumentSource] | None,
|
||||
document_set_names: list[str] | None,
|
||||
slack_context: SlackContext | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# When User Knowledge (user files) is the only knowledge source enabled,
|
||||
# skip federated connectors entirely. User Knowledge mode means the agent
|
||||
# should ONLY use uploaded files, not team connectors like Slack.
|
||||
if user_file_ids and not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping all federated connectors: User Knowledge mode enabled "
|
||||
f"with {len(user_file_ids)} user files and no document sets"
|
||||
)
|
||||
return []
|
||||
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
logger.debug("Slack context detected, checking for Slack bot setup...")
|
||||
|
||||
@@ -47,10 +47,10 @@ class SlackEntities(BaseModel):
|
||||
|
||||
# Message count per slack request
|
||||
max_messages_per_query: int = Field(
|
||||
default=10,
|
||||
default=25,
|
||||
description=(
|
||||
"Maximum number of messages to retrieve per search query. "
|
||||
"Higher values increase API calls and may trigger rate limits."
|
||||
"Higher values provide more context but may be slower."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def delete_unstructured_api_key() -> None:
|
||||
def _sdk_partition_request(
|
||||
file: IO[Any], file_name: str, **kwargs: Any
|
||||
) -> "operations.PartitionRequest":
|
||||
from unstructured_client.models import operations
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
from unstructured_client.models import shared
|
||||
|
||||
file.seek(0, 0)
|
||||
@@ -62,7 +62,7 @@ def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
|
||||
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
|
||||
|
||||
response = unstructured_client.general.partition(req)
|
||||
response = unstructured_client.general.partition(req) # type: ignore
|
||||
elements = dict_to_elements(response.elements)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
@@ -50,10 +50,10 @@ from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.multi_llm import LLMRateLimitError
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
|
||||
@@ -45,7 +45,9 @@ class PgRedisKVStore(KeyValueStore):
|
||||
obj.value = plain_val
|
||||
obj.encrypted_value = encrypted_val
|
||||
else:
|
||||
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
|
||||
obj = KVStore(
|
||||
key=key, value=plain_val, encrypted_value=encrypted_val
|
||||
) # type: ignore
|
||||
db_session.query(KVStore).filter_by(key=key).delete() # just in case
|
||||
db_session.add(obj)
|
||||
db_session.commit()
|
||||
@@ -95,7 +97,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete()
|
||||
result = db_session.query(KVStore).filter_by(key=key).delete() # type: ignore
|
||||
if result == 0:
|
||||
raise KvKeyNotFoundError
|
||||
db_session.commit()
|
||||
|
||||
@@ -14,12 +14,12 @@ from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.chat_llm import LitellmLLM
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.multi_llm import LitellmLLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
|
||||
@@ -152,7 +152,7 @@ def litellm_exception_to_error_msg(
|
||||
if message_attr:
|
||||
upstream_detail = str(message_attr)
|
||||
elif hasattr(core_exception, "api_error"):
|
||||
api_error = core_exception.api_error
|
||||
api_error = core_exception.api_error # type: ignore[attr-defined]
|
||||
if isinstance(api_error, dict):
|
||||
upstream_detail = (
|
||||
api_error.get("message")
|
||||
|
||||
@@ -15,10 +15,10 @@ from typing import cast
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import requests
|
||||
import voyageai # type: ignore[import-untyped]
|
||||
import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from cohere.core.api_error import ApiError
|
||||
from google.oauth2 import service_account
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from httpx import HTTPError
|
||||
from requests import JSONDecodeError
|
||||
from requests import RequestException
|
||||
@@ -89,6 +89,30 @@ _AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
# Thread-local storage for event loops
|
||||
# This prevents creating thousands of event loops during batch processing,
|
||||
# which was causing severe memory leaks with API-based embedding providers
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get or create a thread-local event loop for API embedding calls.
|
||||
|
||||
This prevents creating a new event loop for every batch during embedding,
|
||||
which was causing memory leaks. Instead, each thread reuses the same loop.
|
||||
|
||||
Returns:
|
||||
asyncio.AbstractEventLoop: The thread-local event loop
|
||||
"""
|
||||
if (
|
||||
not hasattr(_thread_local, "loop")
|
||||
or _thread_local.loop is None
|
||||
or _thread_local.loop.is_closed()
|
||||
):
|
||||
_thread_local.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(_thread_local.loop)
|
||||
return _thread_local.loop
|
||||
|
||||
|
||||
WARM_UP_STRINGS = [
|
||||
"Onyx is amazing!",
|
||||
@@ -271,8 +295,8 @@ class CloudEmbedding:
|
||||
embedding_type: str,
|
||||
reduced_dimension: int | None,
|
||||
) -> list[Embedding]:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
from google import genai # type: ignore[import-untyped]
|
||||
from google.genai import types as genai_types # type: ignore[import-untyped]
|
||||
|
||||
if not model:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
@@ -776,16 +800,14 @@ class EmbeddingModel:
|
||||
# Route between direct API calls and model server calls
|
||||
if self.provider_type is not None:
|
||||
# For API providers, make direct API call
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
response = loop.run_until_complete(
|
||||
self._make_direct_api_call(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
# Use thread-local event loop to prevent memory leaks from creating
|
||||
# thousands of event loops during batch processing
|
||||
loop = _get_or_create_event_loop()
|
||||
response = loop.run_until_complete(
|
||||
self._make_direct_api_call(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
else:
|
||||
# For local models, use model server
|
||||
response = self._make_model_server_request(
|
||||
|
||||
@@ -3,8 +3,8 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
|
||||
from tokenizers import Encoding # type: ignore[import-untyped]
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers import Encoding # type: ignore
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
from mistune import Markdown # type: ignore
|
||||
from mistune import Renderer # type: ignore
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
|
||||
@@ -452,7 +452,7 @@ def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
ttl = r.ttl(name)
|
||||
locked = lock.locked()
|
||||
owned = lock.owned()
|
||||
local_token: str | None = lock.local.token
|
||||
local_token: str | None = lock.local.token # type: ignore
|
||||
|
||||
remote_token_raw = r.get(lock.name)
|
||||
if remote_token_raw:
|
||||
|
||||
@@ -16,7 +16,7 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.oauth2.credentials import Credentials # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
@@ -11,7 +11,6 @@ from onyx.db.chat import get_db_search_doc_by_id
|
||||
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -60,7 +59,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_search_docs,
|
||||
),
|
||||
@@ -69,7 +68,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(
|
||||
content=message_text,
|
||||
),
|
||||
@@ -78,7 +77,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
@@ -95,12 +94,12 @@ def create_citation_packets(
|
||||
for citation_info in citation_info_list:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=citation_info,
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -108,20 +107,18 @@ def create_citation_packets(
|
||||
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, obj=ReasoningStart()))
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDelta(
|
||||
reasoning=reasoning_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -133,24 +130,21 @@ def create_image_generation_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=ImageGenerationToolStart(),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=ImageGenerationFinal(images=images),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -167,14 +161,16 @@ def create_custom_tool_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
@@ -184,12 +180,7 @@ def create_custom_tool_packets(
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -204,32 +195,30 @@ def create_fetch_packets(
|
||||
# Emit start packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlStart(),
|
||||
)
|
||||
)
|
||||
# Emit URLs packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlUrls(urls=urls),
|
||||
)
|
||||
)
|
||||
# Emit documents packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlDocuments(
|
||||
documents=[SearchDoc(**doc.model_dump()) for doc in fetch_docs]
|
||||
),
|
||||
)
|
||||
)
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
|
||||
return packets
|
||||
|
||||
|
||||
@@ -244,7 +233,8 @@ def create_search_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolStart(
|
||||
is_internet_search=is_internet_search,
|
||||
),
|
||||
@@ -255,7 +245,8 @@ def create_search_packets(
|
||||
if search_queries:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolQueriesDelta(queries=search_queries),
|
||||
),
|
||||
)
|
||||
@@ -267,7 +258,8 @@ def create_search_packets(
|
||||
)
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolDocumentsDelta(
|
||||
documents=[
|
||||
SearchDoc(**doc.model_dump()) for doc in sorted_search_docs
|
||||
@@ -276,12 +268,7 @@ def create_search_packets(
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -450,8 +437,6 @@ def translate_assistant_message_to_packets(
|
||||
)
|
||||
|
||||
# Add overall stop packet at the end
|
||||
packet_list.append(
|
||||
Packet(placement=Placement(turn_index=final_turn_index), obj=OverallStop())
|
||||
)
|
||||
packet_list.append(Packet(turn_index=final_turn_index, obj=OverallStop()))
|
||||
|
||||
return packet_list
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
|
||||
class StreamingType(Enum):
|
||||
@@ -289,7 +288,20 @@ PacketObj = Union[
|
||||
]
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int
|
||||
|
||||
|
||||
class Packet(BaseModel):
|
||||
placement: Placement
|
||||
turn_index: int | None
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
|
||||
obj: Annotated[PacketObj, Field(discriminator="type")]
|
||||
|
||||
@@ -4,7 +4,6 @@ import re
|
||||
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -63,7 +62,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=SearchDoc.from_saved_search_docs(final_documents or []),
|
||||
),
|
||||
@@ -84,7 +83,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(
|
||||
content=adjusted_message_text,
|
||||
),
|
||||
@@ -93,7 +92,7 @@ def create_message_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
@@ -110,12 +109,12 @@ def create_citation_packets(
|
||||
for citation_info in citation_info_list:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=citation_info,
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -123,20 +122,18 @@ def create_citation_packets(
|
||||
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
|
||||
)
|
||||
packets.append(Packet(turn_index=turn_index, obj=ReasoningStart()))
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDelta(
|
||||
reasoning=reasoning_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -148,19 +145,19 @@ def create_image_generation_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=ImageGenerationToolStart(),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=ImageGenerationFinal(images=images),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -176,14 +173,14 @@ def create_custom_tool_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=CustomToolStart(tool_name=tool_name),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=tool_name,
|
||||
response_type=response_type,
|
||||
@@ -193,7 +190,7 @@ def create_custom_tool_packets(
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -207,27 +204,27 @@ def create_fetch_packets(
|
||||
# Emit start packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=OpenUrlStart(),
|
||||
)
|
||||
)
|
||||
# Emit URLs packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=OpenUrlUrls(urls=urls),
|
||||
)
|
||||
)
|
||||
# Emit documents packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=OpenUrlDocuments(
|
||||
documents=SearchDoc.from_saved_search_docs(fetch_docs)
|
||||
),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
return packets
|
||||
|
||||
|
||||
@@ -241,7 +238,7 @@ def create_search_packets(
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolStart(
|
||||
is_internet_search=is_internet_search,
|
||||
),
|
||||
@@ -252,7 +249,7 @@ def create_search_packets(
|
||||
if search_queries:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolQueriesDelta(queries=search_queries),
|
||||
),
|
||||
)
|
||||
@@ -261,13 +258,13 @@ def create_search_packets(
|
||||
if saved_search_docs:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolDocumentsDelta(
|
||||
documents=SearchDoc.from_saved_search_docs(saved_search_docs)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
@@ -34,13 +34,13 @@ from onyx.prompts.deep_research.research_agent import RESEARCH_REPORT_PROMPT
|
||||
from onyx.prompts.deep_research.research_agent import USER_REPORT_QUERY
|
||||
from onyx.prompts.prompt_utils import get_current_llm_day_time
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -93,7 +93,7 @@ def generate_intermediate_report(
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=placement,
|
||||
turn_index=999, # TODO
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
reasoning_effort=ReasoningEffort.LOW,
|
||||
@@ -127,22 +127,24 @@ def run_research_agent_call(
|
||||
token_counter: Callable[[str], int],
|
||||
user_identity: LLMUserIdentity | None,
|
||||
) -> ResearchAgentCallResult:
|
||||
research_cycle_count = 0
|
||||
cycle_count = 0
|
||||
llm_cycle_count = 0
|
||||
current_tools = tools
|
||||
gathered_documents: list[SearchDoc] | None = None
|
||||
reasoning_cycles = 0
|
||||
just_ran_web_search = False
|
||||
|
||||
turn_index = research_agent_call.placement.turn_index
|
||||
tab_index = research_agent_call.placement.tab_index
|
||||
turn_index = research_agent_call.turn_index
|
||||
tab_index = research_agent_call.tab_index
|
||||
|
||||
# If this fails to parse, we can't run the loop anyway, let this one fail in that case
|
||||
research_topic = research_agent_call.tool_args[RESEARCH_AGENT_TASK_KEY]
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=0,
|
||||
obj=ResearchAgentStart(research_task=research_topic),
|
||||
)
|
||||
)
|
||||
@@ -155,9 +157,8 @@ def run_research_agent_call(
|
||||
msg_history: list[ChatMessageSimple] = [initial_user_message]
|
||||
|
||||
citation_mapping: dict[int, str] = {}
|
||||
while research_cycle_count <= RESEARCH_CYCLE_CAP:
|
||||
if research_cycle_count == RESEARCH_CYCLE_CAP:
|
||||
# For the last cycle, do not use any more searches, only reason or generate a report
|
||||
while cycle_count <= RESEARCH_CYCLE_CAP:
|
||||
if cycle_count == RESEARCH_CYCLE_CAP:
|
||||
current_tools = [
|
||||
tool
|
||||
for tool in tools
|
||||
@@ -194,7 +195,7 @@ def run_research_agent_call(
|
||||
system_prompt_str = system_prompt_template.format(
|
||||
available_tools=tools_description,
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=research_cycle_count,
|
||||
current_cycle_count=cycle_count,
|
||||
optional_internal_search_tool_description=internal_search_tip,
|
||||
optional_web_search_tool_description=web_search_tip,
|
||||
optional_open_urls_tool_description=open_urls_tip,
|
||||
@@ -234,11 +235,7 @@ def run_research_agent_call(
|
||||
+ research_agent_tools,
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=llm_cycle_count + reasoning_cycles,
|
||||
),
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
reasoning_effort=ReasoningEffort.LOW,
|
||||
@@ -253,13 +250,6 @@ def run_research_agent_call(
|
||||
|
||||
just_ran_web_search = False
|
||||
|
||||
if any(
|
||||
tool_call.tool_name in {SearchTool.NAME, WebSearchTool.NAME}
|
||||
for tool_call in tool_calls
|
||||
):
|
||||
# Only the search actions increment the cycle for the max cycle count
|
||||
research_cycle_count += 1
|
||||
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
final_report = generate_intermediate_report(
|
||||
@@ -271,10 +261,8 @@ def run_research_agent_call(
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=llm_cycle_count + reasoning_cycles,
|
||||
),
|
||||
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
|
||||
), # TODO
|
||||
)
|
||||
return ResearchAgentCallResult(
|
||||
intermediate_report=final_report, search_docs=[]
|
||||
@@ -328,7 +316,7 @@ def run_research_agent_call(
|
||||
raise ValueError("Tool response missing tool_call reference")
|
||||
|
||||
tool_call = tool_response.tool_call
|
||||
tab_index = tool_call.placement.tab_index
|
||||
tab_index = tool_call.tab_index
|
||||
|
||||
tool = tools_by_name.get(tool_call.tool_name)
|
||||
if not tool:
|
||||
@@ -404,10 +392,8 @@ def run_research_agent_call(
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=llm_cycle_count + reasoning_cycles,
|
||||
),
|
||||
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
|
||||
), # TODO
|
||||
)
|
||||
return ResearchAgentCallResult(intermediate_report=final_report, search_docs=[])
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
|
||||
@@ -39,7 +38,8 @@ class ToolCallKickoff(BaseModel):
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
placement: Placement
|
||||
turn_index: int
|
||||
tab_index: int
|
||||
|
||||
def to_msg_str(self) -> str:
|
||||
return json.dumps(
|
||||
|
||||
@@ -3,13 +3,15 @@ from __future__ import annotations
|
||||
import abc
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolResponse
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import Session
|
||||
from onyx.tools.models import ToolResponse
|
||||
|
||||
|
||||
TOverride = TypeVar("TOverride")
|
||||
@@ -71,7 +73,7 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
"""
|
||||
Emit the start packet for this tool. Each tool implementation should
|
||||
emit its specific start packet type.
|
||||
@@ -85,7 +87,11 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
@abc.abstractmethod
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
# The run must know its turn because the "Tool" may actually be more of an "Agent" which can call
|
||||
# other tools and must pass in this information potentially deeper down.
|
||||
turn_index: int,
|
||||
# Tab index for parallel tool calls (default 0 for single tool calls)
|
||||
tab_index: int,
|
||||
# Specific tool override arguments that are not provided by the LLM
|
||||
# For example when calling the internal search tool, the original user query is passed along too (but not by the LLM)
|
||||
override_kwargs: TOverride,
|
||||
@@ -30,8 +30,8 @@ from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import DynamicSchemaInfo
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
|
||||
@@ -14,17 +14,16 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import CustomToolUserFileSnapshot
|
||||
from onyx.tools.models import DynamicSchemaInfo
|
||||
from onyx.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
||||
from onyx.tools.tool_implementations.custom.openapi_parsing import (
|
||||
openapi_to_method_specs,
|
||||
@@ -134,17 +133,19 @@ class CustomTool(Tool[None]):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolStart(tool_name=self._name),
|
||||
)
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -211,7 +212,8 @@ class CustomTool(Tool[None]):
|
||||
# Emit CustomToolDelta packet
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
|
||||
@@ -13,14 +13,13 @@ from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
@@ -121,10 +120,11 @@ class ImageGenerationTool(Tool[None]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=ImageGenerationToolStart(),
|
||||
)
|
||||
)
|
||||
@@ -132,7 +132,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
def _generate_image(
|
||||
self, prompt: str, shape: ImageShape
|
||||
) -> tuple[ImageGenerationResponse, Any]:
|
||||
from litellm import image_generation
|
||||
from litellm import image_generation # type: ignore
|
||||
|
||||
if shape == ImageShape.LANDSCAPE:
|
||||
if "gpt-image-1" in self.model:
|
||||
@@ -209,7 +209,8 @@ class ImageGenerationTool(Tool[None]):
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -258,7 +259,8 @@ class ImageGenerationTool(Tool[None]):
|
||||
# Emit a heartbeat packet to prevent timeout
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=ImageGenerationToolHeartbeat(),
|
||||
)
|
||||
)
|
||||
@@ -302,7 +304,8 @@ class ImageGenerationTool(Tool[None]):
|
||||
# Emit final packet with generated images
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=ImageGenerationFinal(images=generated_images_metadata),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,9 +4,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -69,12 +68,13 @@ class KnowledgeGraphTool(Tool[None]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
raise NotImplementedError("KnowledgeGraphTool.emit_start is not implemented.")
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
|
||||
@@ -6,13 +6,12 @@ from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
from onyx.db.models import MCPConnectionConfig
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import call_mcp_tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -88,17 +87,19 @@ class MCPTool(Tool[None]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolStart(tool_name=self._name),
|
||||
)
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -145,7 +146,8 @@ class MCPTool(Tool[None]):
|
||||
# Emit CustomToolDelta packet
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
response_type="json",
|
||||
@@ -180,7 +182,8 @@ class MCPTool(Tool[None]):
|
||||
# Emit CustomToolDelta packet
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
response_type="json",
|
||||
@@ -234,7 +237,8 @@ class MCPTool(Tool[None]):
|
||||
# Emit CustomToolDelta packet
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=CustomToolDelta(
|
||||
tool_name=self._name,
|
||||
response_type="json",
|
||||
|
||||
@@ -9,14 +9,13 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContentProvider
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
get_default_content_provider,
|
||||
@@ -179,25 +178,28 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
"""Emit start packet to signal tool has started."""
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlStart(),
|
||||
)
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: OpenURLToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
"""Execute the open URL tool to fetch content from the specified URLs.
|
||||
|
||||
Args:
|
||||
placement: The placement info (turn_index and tab_index) for this tool call.
|
||||
turn_index: The current turn index in the conversation.
|
||||
tab_index: The tab index for parallel tool calls.
|
||||
override_kwargs: Override arguments including starting citation number
|
||||
and existing citation_mapping to reuse citations for already-cited URLs.
|
||||
**llm_kwargs: Arguments provided by the LLM, including the 'urls' field.
|
||||
@@ -209,7 +211,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlUrls(urls=urls),
|
||||
)
|
||||
)
|
||||
@@ -245,7 +248,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
# Emit documents packet AFTER crawling completes
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=OpenUrlDocuments(documents=search_docs),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -14,15 +14,14 @@ from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import LlmPythonExecutionResult
|
||||
from onyx.tools.models import PythonExecutionFile
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
@@ -115,7 +114,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
"""Emit start packet for this tool. Code will be emitted in run() method."""
|
||||
# Note: PythonToolStart requires code, but we don't have it in emit_start
|
||||
# The code is available in run() method via llm_kwargs
|
||||
@@ -123,7 +122,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: PythonToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -131,7 +131,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
Execute Python code in the Code Interpreter service.
|
||||
|
||||
Args:
|
||||
placement: The placement info (turn_index and tab_index) for this tool call.
|
||||
turn_index: The turn index for this tool execution
|
||||
tab_index: The tab index for parallel tool calls
|
||||
override_kwargs: Contains chat_files to stage for execution
|
||||
**llm_kwargs: Contains 'code' parameter from LLM
|
||||
|
||||
@@ -144,7 +145,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
# Emit start event with the code
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=PythonToolStart(code=code),
|
||||
)
|
||||
)
|
||||
@@ -251,7 +253,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
# Emit delta with stdout/stderr and generated files
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=PythonToolDelta(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
@@ -286,7 +289,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
|
||||
@@ -64,14 +64,13 @@ from onyx.secondary_llm_flows.document_filter import select_chunks_for_relevance
|
||||
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
|
||||
from onyx.secondary_llm_flows.query_expansion import keyword_query_expansion
|
||||
from onyx.secondary_llm_flows.query_expansion import semantic_query_rephrase
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.constants import (
|
||||
KEYWORD_QUERY_HYBRID_ALPHA,
|
||||
)
|
||||
@@ -303,19 +302,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
"""Check if search tool is available.
|
||||
|
||||
The search tool is available if ANY of the following exist:
|
||||
- Regular connectors (team knowledge)
|
||||
- Federated connectors (e.g., Slack)
|
||||
- User files (User Knowledge mode)
|
||||
"""
|
||||
from onyx.db.connector import check_user_files_exist
|
||||
|
||||
return (
|
||||
check_connectors_exist(db_session)
|
||||
or check_federated_connectors_exist(db_session)
|
||||
or check_user_files_exist(db_session)
|
||||
"""Check if search tool is available by verifying connectors exist."""
|
||||
return check_connectors_exist(db_session) or check_federated_connectors_exist(
|
||||
db_session
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -356,10 +345,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolStart(),
|
||||
)
|
||||
)
|
||||
@@ -367,7 +357,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
@log_function_time(print_only=True)
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: SearchToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -487,7 +478,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Emit the queries early so the UI can display them immediately
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolQueriesDelta(
|
||||
queries=all_queries,
|
||||
),
|
||||
@@ -595,7 +587,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolDocumentsDelta(
|
||||
documents=final_ui_docs,
|
||||
),
|
||||
|
||||
@@ -9,14 +9,13 @@ from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.web_search import fetch_active_web_search_provider
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.utils import (
|
||||
convert_inference_sections_to_llm_string,
|
||||
)
|
||||
@@ -111,10 +110,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
def emit_start(self, turn_index: int, tab_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolStart(is_internet_search=True),
|
||||
)
|
||||
)
|
||||
@@ -129,7 +129,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
turn_index: int,
|
||||
tab_index: int,
|
||||
override_kwargs: WebSearchToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
@@ -139,7 +140,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
# Emit queries
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolQueriesDelta(queries=queries),
|
||||
)
|
||||
)
|
||||
@@ -203,7 +205,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
# Emit documents
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SearchToolDocumentsDelta(documents=search_docs),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -9,13 +9,13 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -73,8 +73,9 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
tool_call_id=calls[0].tool_call_id, # Use first call's ID
|
||||
tool_name=tool_name,
|
||||
tool_args=merged_args,
|
||||
# Use first call's placement since merged calls become a single call
|
||||
placement=calls[0].placement,
|
||||
turn_index=calls[0].turn_index,
|
||||
# Use first call's tab_index since merged calls become a single call
|
||||
tab_index=calls[0].tab_index,
|
||||
)
|
||||
merged_calls.append(merged_call)
|
||||
else:
|
||||
@@ -93,12 +94,16 @@ def _run_single_tool(
|
||||
|
||||
This function is designed to be run in parallel via run_functions_tuples_in_parallel.
|
||||
"""
|
||||
turn_index = tool_call.turn_index
|
||||
tab_index = tool_call.tab_index
|
||||
|
||||
with function_span(tool.name) as span_fn:
|
||||
span_fn.span_data.input = str(tool_call.tool_args)
|
||||
try:
|
||||
tool_response = tool.run(
|
||||
placement=tool_call.placement,
|
||||
turn_index=turn_index,
|
||||
override_kwargs=override_kwargs,
|
||||
tab_index=tab_index,
|
||||
**tool_call.tool_args,
|
||||
)
|
||||
span_fn.span_data.output = tool_response.llm_facing_response
|
||||
@@ -122,7 +127,8 @@ def _run_single_tool(
|
||||
# Emit SectionEnd after tool completes (success or failure)
|
||||
tool.emitter.emit(
|
||||
Packet(
|
||||
placement=tool_call.placement,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
@@ -204,7 +210,7 @@ def run_tool_calls(
|
||||
tool = tools_by_name[tool_call.tool_name]
|
||||
|
||||
# Emit the tool start packet before running the tool
|
||||
tool.emit_start(placement=tool_call.placement)
|
||||
tool.emit_start(turn_index=tool_call.turn_index, tab_index=tool_call.tab_index)
|
||||
|
||||
override_kwargs: (
|
||||
SearchToolOverrideKwargs
|
||||
|
||||
@@ -9,7 +9,7 @@ from onyx.db.models import LLMProvider
|
||||
from onyx.llm.utils import find_model_obj
|
||||
from onyx.llm.utils import get_model_map
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool:
|
||||
|
||||
@@ -74,20 +74,20 @@ class BraintrustTracingProcessor(TracingProcessor):
|
||||
|
||||
current_context = braintrust.current_span()
|
||||
if current_context != NOOP_SPAN:
|
||||
self._spans[trace.trace_id] = current_context.start_span(
|
||||
self._spans[trace.trace_id] = current_context.start_span( # type: ignore[assignment]
|
||||
name=trace.name,
|
||||
span_attributes={"type": "task", "name": trace.name},
|
||||
metadata=metadata,
|
||||
)
|
||||
elif self._logger is not None:
|
||||
self._spans[trace.trace_id] = self._logger.start_span(
|
||||
self._spans[trace.trace_id] = self._logger.start_span( # type: ignore[assignment]
|
||||
span_attributes={"type": "task", "name": trace.name},
|
||||
span_id=trace.trace_id,
|
||||
root_span_id=trace.trace_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
self._spans[trace.trace_id] = braintrust.start_span(
|
||||
self._spans[trace.trace_id] = braintrust.start_span( # type: ignore[assignment]
|
||||
id=trace.trace_id,
|
||||
span_attributes={"type": "task", "name": trace.name},
|
||||
metadata=metadata,
|
||||
|
||||
@@ -208,7 +208,7 @@ def _as_utc_nano(dt: datetime) -> int:
|
||||
def _get_span_name(obj: Span[Any]) -> str:
|
||||
if hasattr(data := obj.span_data, "name") and isinstance(name := data.name, str):
|
||||
return name
|
||||
return obj.span_data.type
|
||||
return obj.span_data.type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def _get_span_kind(obj: SpanData) -> str:
|
||||
|
||||
@@ -221,7 +221,7 @@ def run_functions_tuples_in_parallel(
|
||||
results.append((index, future.result()))
|
||||
except Exception as e:
|
||||
logger.exception(f"Function at index {index} failed due to {e}")
|
||||
results.append((index, None))
|
||||
results.append((index, None)) # type: ignore
|
||||
|
||||
if not allow_failures:
|
||||
raise
|
||||
@@ -336,7 +336,7 @@ def run_with_timeout(
|
||||
if task.is_alive():
|
||||
task.end()
|
||||
|
||||
return task.result
|
||||
return task.result # type: ignore
|
||||
|
||||
|
||||
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
|
||||
@@ -352,7 +352,7 @@ def run_in_background(
|
||||
"""
|
||||
context = contextvars.copy_context()
|
||||
# Timeout not used in the non-blocking case
|
||||
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
|
||||
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
|
||||
task.start()
|
||||
return cast(TimeoutThread[R], task)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
exclude = [
|
||||
|
||||
@@ -82,10 +82,6 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# tqdm
|
||||
comm==0.2.3
|
||||
# via ipykernel
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
debugpy==1.8.17
|
||||
# via ipykernel
|
||||
decorator==5.2.1
|
||||
@@ -116,8 +112,6 @@ filelock==3.20.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# virtualenv
|
||||
fonttools==4.61.1
|
||||
# via matplotlib
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -239,16 +233,12 @@ jupyter-core==5.9.1
|
||||
# via
|
||||
# ipykernel
|
||||
# jupyter-client
|
||||
kiwisolver==1.4.9
|
||||
# via matplotlib
|
||||
litellm==1.79.0
|
||||
# via onyx
|
||||
manygo==0.2.0
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
matplotlib==3.10.8
|
||||
# via onyx
|
||||
matplotlib-inline==0.2.1
|
||||
# via
|
||||
# ipykernel
|
||||
@@ -271,8 +261,6 @@ nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# contourpy
|
||||
# matplotlib
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
@@ -294,7 +282,6 @@ packaging==24.2
|
||||
# hatchling
|
||||
# huggingface-hub
|
||||
# ipykernel
|
||||
# matplotlib
|
||||
# pytest
|
||||
pandas-stubs==2.2.3.241009
|
||||
# via onyx
|
||||
@@ -308,8 +295,6 @@ pathspec==0.12.1
|
||||
# hatchling
|
||||
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
# via ipython
|
||||
pillow==12.0.0
|
||||
# via matplotlib
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# black
|
||||
@@ -378,8 +363,6 @@ pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# onyx
|
||||
@@ -398,7 +381,6 @@ python-dateutil==2.8.2
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# jupyter-client
|
||||
# matplotlib
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# litellm
|
||||
|
||||
@@ -31,7 +31,7 @@ def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
|
||||
)
|
||||
"""
|
||||
try:
|
||||
bucket_type, bucket_name, *rest = request.param
|
||||
bucket_type, bucket_name, *rest = request.param # type: ignore[misc]
|
||||
except Exception as e:
|
||||
raise AssertionError(
|
||||
"blob_connector requires (BlobType, bucket_name, [init_kwargs])"
|
||||
|
||||
@@ -36,7 +36,7 @@ def _prepare_connector(exclude: bool) -> GoogleDriveConnector:
|
||||
exclude_domain_link_only=exclude,
|
||||
)
|
||||
connector._creds = object() # type: ignore[assignment]
|
||||
connector._primary_admin_email = "admin@example.com"
|
||||
connector._primary_admin_email = "admin@example.com" # type: ignore[attr-defined]
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -1,442 +0,0 @@
|
||||
"""
|
||||
External dependency unit tests for document processing job priority.
|
||||
|
||||
Tests that first-time indexing connectors (no last_successful_index_time)
|
||||
get higher priority than re-indexing jobs from connectors that have
|
||||
previously completed indexing.
|
||||
|
||||
Uses real Redis for locking and real database objects for CC pairs and search settings.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_test_connector(db_session: Session, name: str) -> Connector:
|
||||
"""Create a test connector with all required fields."""
|
||||
connector = Connector(
|
||||
name=name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
refresh_freq=3600,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.commit()
|
||||
db_session.refresh(connector)
|
||||
return connector
|
||||
|
||||
|
||||
def _create_test_credential(db_session: Session) -> Credential:
|
||||
"""Create a test credential with all required fields."""
|
||||
credential = Credential(
|
||||
name=f"test_credential_{uuid4().hex[:8]}",
|
||||
source=DocumentSource.FILE,
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
db_session.refresh(credential)
|
||||
return credential
|
||||
|
||||
|
||||
def _create_test_cc_pair(
|
||||
db_session: Session,
|
||||
connector: Connector,
|
||||
credential: Credential,
|
||||
status: ConnectorCredentialPairStatus,
|
||||
name: str,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> ConnectorCredentialPair:
|
||||
"""Create a connector credential pair with the specified status."""
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
name=name,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
status=status,
|
||||
access_type=AccessType.PUBLIC,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
)
|
||||
db_session.add(cc_pair)
|
||||
db_session.commit()
|
||||
db_session.refresh(cc_pair)
|
||||
return cc_pair
|
||||
|
||||
|
||||
def _create_test_search_settings(
|
||||
db_session: Session, index_name: str
|
||||
) -> SearchSettings:
|
||||
"""Create test search settings with all required fields."""
|
||||
search_settings = SearchSettings(
|
||||
model_name="test-model",
|
||||
model_dim=768,
|
||||
normalize=True,
|
||||
query_prefix="",
|
||||
passage_prefix="",
|
||||
status=IndexModelStatus.PRESENT,
|
||||
index_name=index_name,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
)
|
||||
db_session.add(search_settings)
|
||||
db_session.commit()
|
||||
db_session.refresh(search_settings)
|
||||
return search_settings
|
||||
|
||||
|
||||
class TestDocfetchingTaskPriorityWithRealObjects:
|
||||
"""
|
||||
Tests for document fetching task priority based on last_successful_index_time.
|
||||
|
||||
Uses real Redis for locking and real database objects for CC pairs
|
||||
and search settings.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_successful_index,expected_priority",
|
||||
[
|
||||
# First-time indexing (no last_successful_index_time) should get HIGH priority
|
||||
(False, OnyxCeleryPriority.HIGH),
|
||||
# Re-indexing (has last_successful_index_time) should get MEDIUM priority
|
||||
(True, OnyxCeleryPriority.MEDIUM),
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
|
||||
)
|
||||
def test_priority_based_on_last_successful_index_time(
|
||||
self,
|
||||
mock_try_create_index_attempt: MagicMock,
|
||||
db_session: Session,
|
||||
has_successful_index: bool,
|
||||
expected_priority: OnyxCeleryPriority,
|
||||
) -> None:
|
||||
"""
|
||||
Test that first-time indexing connectors get higher priority than re-indexing.
|
||||
|
||||
Priority is determined by last_successful_index_time:
|
||||
- None (never indexed): HIGH priority
|
||||
- Has timestamp (previously indexed): MEDIUM priority
|
||||
|
||||
Uses real Redis for locking and real database objects.
|
||||
"""
|
||||
# Create unique names to avoid conflicts between test runs
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
# Determine last_successful_index_time based on the test case
|
||||
last_successful_index_time = (
|
||||
datetime.now(timezone.utc) if has_successful_index else None
|
||||
)
|
||||
|
||||
# Create real database objects
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_{has_successful_index}_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.ACTIVE,
|
||||
name=f"test_cc_pair_{has_successful_index}_{unique_suffix}",
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
)
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_{unique_suffix}"
|
||||
)
|
||||
|
||||
# Mock the index attempt creation to return a valid ID
|
||||
mock_try_create_index_attempt.return_value = 12345
|
||||
|
||||
# Mock celery app to capture task submission
|
||||
mock_celery_app = MagicMock()
|
||||
mock_celery_app.send_task.return_value = MagicMock()
|
||||
|
||||
# Use real Redis client
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# Call the function with real objects
|
||||
result = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
|
||||
# Verify task was created
|
||||
assert result == 12345
|
||||
|
||||
# Verify send_task was called with the expected priority
|
||||
mock_celery_app.send_task.assert_called_once()
|
||||
call_kwargs = mock_celery_app.send_task.call_args
|
||||
actual_priority = call_kwargs.kwargs["priority"]
|
||||
assert actual_priority == expected_priority, (
|
||||
f"Expected priority {expected_priority} for has_successful_index={has_successful_index}, "
|
||||
f"but got {actual_priority}"
|
||||
)
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
|
||||
)
|
||||
def test_no_task_created_when_deleting(
|
||||
self,
|
||||
mock_try_create_index_attempt: MagicMock,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Test that no task is created when connector is in DELETING status."""
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_deleting_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.DELETING,
|
||||
name=f"test_cc_pair_deleting_{unique_suffix}",
|
||||
)
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_deleting_{unique_suffix}"
|
||||
)
|
||||
|
||||
mock_celery_app = MagicMock()
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
result = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
|
||||
# Verify no task was created
|
||||
assert result is None
|
||||
mock_celery_app.send_task.assert_not_called()
|
||||
mock_try_create_index_attempt.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
|
||||
)
|
||||
def test_redis_lock_prevents_concurrent_task_creation(
|
||||
self,
|
||||
mock_try_create_index_attempt: MagicMock,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the Redis lock prevents concurrent task creation attempts.
|
||||
|
||||
This test uses real Redis to verify the locking mechanism works correctly.
|
||||
When the lock is already held, the function should return None without
|
||||
attempting to create a task.
|
||||
"""
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_lock_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.INITIAL_INDEXING,
|
||||
name=f"test_cc_pair_lock_{unique_suffix}",
|
||||
)
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_lock_{unique_suffix}"
|
||||
)
|
||||
|
||||
mock_try_create_index_attempt.return_value = 12345
|
||||
mock_celery_app = MagicMock()
|
||||
mock_celery_app.send_task.return_value = MagicMock()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# Acquire the lock before calling the function
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
|
||||
lock = redis_client.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
try:
|
||||
acquired = lock.acquire(blocking=False)
|
||||
assert acquired, "Failed to acquire lock for test"
|
||||
|
||||
# Now try to create a task - should fail because lock is held
|
||||
result = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
|
||||
# Should return None because lock couldn't be acquired
|
||||
assert result is None
|
||||
mock_celery_app.send_task.assert_not_called()
|
||||
|
||||
finally:
|
||||
# Always release the lock
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
|
||||
)
|
||||
def test_lock_released_after_successful_task_creation(
|
||||
self,
|
||||
mock_try_create_index_attempt: MagicMock,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the Redis lock is released after successful task creation.
|
||||
|
||||
This verifies that subsequent calls can acquire the lock and create tasks.
|
||||
"""
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_release_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.INITIAL_INDEXING,
|
||||
name=f"test_cc_pair_release_{unique_suffix}",
|
||||
)
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_release_{unique_suffix}"
|
||||
)
|
||||
|
||||
mock_try_create_index_attempt.return_value = 12345
|
||||
mock_celery_app = MagicMock()
|
||||
mock_celery_app.send_task.return_value = MagicMock()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# First call should succeed
|
||||
result1 = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
assert result1 == 12345
|
||||
|
||||
# Reset mocks for second call
|
||||
mock_celery_app.reset_mock()
|
||||
mock_try_create_index_attempt.reset_mock()
|
||||
mock_try_create_index_attempt.return_value = 67890
|
||||
|
||||
# Second call should also succeed (lock was released)
|
||||
result2 = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
assert result2 == 67890
|
||||
|
||||
# Both calls should have submitted tasks
|
||||
mock_celery_app.send_task.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.docfetching.task_creation_utils.IndexingCoordination.try_create_index_attempt"
|
||||
)
|
||||
def test_user_file_connector_uses_correct_queue(
|
||||
self,
|
||||
mock_try_create_index_attempt: MagicMock,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Test that user file connectors use the USER_FILES_INDEXING queue.
|
||||
"""
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_userfile_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.INITIAL_INDEXING,
|
||||
name=f"test_cc_pair_userfile_{unique_suffix}",
|
||||
)
|
||||
# Mark as user file
|
||||
cc_pair.is_user_file = True
|
||||
db_session.commit()
|
||||
db_session.refresh(cc_pair)
|
||||
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_userfile_{unique_suffix}"
|
||||
)
|
||||
|
||||
mock_try_create_index_attempt.return_value = 12345
|
||||
mock_celery_app = MagicMock()
|
||||
mock_celery_app.send_task.return_value = MagicMock()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
result = try_creating_docfetching_task(
|
||||
celery_app=mock_celery_app,
|
||||
cc_pair=cc_pair,
|
||||
search_settings=search_settings,
|
||||
reindex=False,
|
||||
db_session=db_session,
|
||||
r=redis_client,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
|
||||
assert result == 12345
|
||||
mock_celery_app.send_task.assert_called_once()
|
||||
call_kwargs = mock_celery_app.send_task.call_args
|
||||
assert call_kwargs.kwargs["queue"] == OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
# User files with no last_successful_index_time should get HIGH priority
|
||||
assert call_kwargs.kwargs["priority"] == OnyxCeleryPriority.HIGH
|
||||
@@ -1,258 +0,0 @@
|
||||
"""
|
||||
External dependency unit tests for docprocessing task priority.
|
||||
|
||||
Tests that docprocessing tasks spawned by connector_document_extraction
|
||||
get the correct priority based on last_successful_index_time.
|
||||
|
||||
Uses real database objects for CC pairs, search settings, and index attempts.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.indexing.run_docfetching import connector_document_extraction
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_test_connector(db_session: Session, name: str) -> Connector:
|
||||
"""Create a test connector with all required fields."""
|
||||
connector = Connector(
|
||||
name=name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
refresh_freq=3600,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.commit()
|
||||
db_session.refresh(connector)
|
||||
return connector
|
||||
|
||||
|
||||
def _create_test_credential(db_session: Session) -> Credential:
|
||||
"""Create a test credential with all required fields."""
|
||||
credential = Credential(
|
||||
name=f"test_credential_{uuid4().hex[:8]}",
|
||||
source=DocumentSource.FILE,
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
db_session.refresh(credential)
|
||||
return credential
|
||||
|
||||
|
||||
def _create_test_cc_pair(
|
||||
db_session: Session,
|
||||
connector: Connector,
|
||||
credential: Credential,
|
||||
status: ConnectorCredentialPairStatus,
|
||||
name: str,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> ConnectorCredentialPair:
|
||||
"""Create a connector credential pair with the specified status."""
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
name=name,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
status=status,
|
||||
access_type=AccessType.PUBLIC,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
)
|
||||
db_session.add(cc_pair)
|
||||
db_session.commit()
|
||||
db_session.refresh(cc_pair)
|
||||
return cc_pair
|
||||
|
||||
|
||||
def _create_test_search_settings(
|
||||
db_session: Session, index_name: str
|
||||
) -> SearchSettings:
|
||||
"""Create test search settings with all required fields."""
|
||||
search_settings = SearchSettings(
|
||||
model_name="test-model",
|
||||
model_dim=768,
|
||||
normalize=True,
|
||||
query_prefix="",
|
||||
passage_prefix="",
|
||||
status=IndexModelStatus.PRESENT,
|
||||
index_name=index_name,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
)
|
||||
db_session.add(search_settings)
|
||||
db_session.commit()
|
||||
db_session.refresh(search_settings)
|
||||
return search_settings
|
||||
|
||||
|
||||
def _create_test_index_attempt(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
from_beginning: bool = False,
|
||||
) -> IndexAttempt:
|
||||
"""Create a test index attempt with the specified cc_pair and search_settings."""
|
||||
index_attempt = IndexAttempt(
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
from_beginning=from_beginning,
|
||||
status=IndexingStatus.IN_PROGRESS,
|
||||
celery_task_id=f"test_celery_task_{uuid4().hex[:8]}",
|
||||
)
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
db_session.refresh(index_attempt)
|
||||
return index_attempt
|
||||
|
||||
|
||||
class TestDocprocessingPriorityInDocumentExtraction:
|
||||
"""
|
||||
Tests for docprocessing task priority within connector_document_extraction.
|
||||
|
||||
Verifies that the priority passed to docprocessing tasks is determined
|
||||
by last_successful_index_time on the cc_pair.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"has_successful_index,expected_priority",
|
||||
[
|
||||
# First-time indexing (no last_successful_index_time) should get HIGH priority
|
||||
(False, OnyxCeleryPriority.HIGH),
|
||||
# Re-indexing (has last_successful_index_time) should get MEDIUM priority
|
||||
(True, OnyxCeleryPriority.MEDIUM),
|
||||
],
|
||||
)
|
||||
@patch("onyx.background.indexing.run_docfetching.get_document_batch_storage")
|
||||
@patch("onyx.background.indexing.run_docfetching.MemoryTracer")
|
||||
@patch("onyx.background.indexing.run_docfetching._get_connector_runner")
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.get_recent_completed_attempts_for_cc_pair"
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.indexing.run_docfetching.get_last_successful_attempt_poll_range_end"
|
||||
)
|
||||
@patch("onyx.background.indexing.run_docfetching.save_checkpoint")
|
||||
@patch("onyx.background.indexing.run_docfetching.get_latest_valid_checkpoint")
|
||||
def test_docprocessing_priority_based_on_last_successful_index_time(
|
||||
self,
|
||||
mock_get_latest_valid_checkpoint: MagicMock,
|
||||
mock_save_checkpoint: MagicMock,
|
||||
mock_get_last_successful_attempt_poll_range_end: MagicMock,
|
||||
mock_get_recent_completed_attempts: MagicMock,
|
||||
mock_get_connector_runner: MagicMock,
|
||||
mock_memory_tracer_class: MagicMock,
|
||||
mock_get_batch_storage: MagicMock,
|
||||
db_session: Session,
|
||||
has_successful_index: bool,
|
||||
expected_priority: OnyxCeleryPriority,
|
||||
) -> None:
|
||||
"""
|
||||
Test that docprocessing tasks get the correct priority based on
|
||||
last_successful_index_time.
|
||||
|
||||
Priority is determined by last_successful_index_time:
|
||||
- None (never indexed): HIGH priority
|
||||
- Has timestamp (previously indexed): MEDIUM priority
|
||||
|
||||
Uses real database objects for CC pairs and search settings.
|
||||
"""
|
||||
unique_suffix = uuid4().hex[:8]
|
||||
|
||||
# Determine last_successful_index_time based on the test case
|
||||
last_successful_index_time = (
|
||||
datetime.now(timezone.utc) if has_successful_index else None
|
||||
)
|
||||
|
||||
# Create real database objects
|
||||
connector = _create_test_connector(
|
||||
db_session, f"test_connector_docproc_{has_successful_index}_{unique_suffix}"
|
||||
)
|
||||
credential = _create_test_credential(db_session)
|
||||
cc_pair = _create_test_cc_pair(
|
||||
db_session,
|
||||
connector,
|
||||
credential,
|
||||
ConnectorCredentialPairStatus.ACTIVE,
|
||||
name=f"test_cc_pair_docproc_{has_successful_index}_{unique_suffix}",
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
)
|
||||
search_settings = _create_test_search_settings(
|
||||
db_session, f"test_index_docproc_{unique_suffix}"
|
||||
)
|
||||
index_attempt = _create_test_index_attempt(
|
||||
db_session, cc_pair, search_settings, from_beginning=False
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_batch_storage = MagicMock()
|
||||
mock_get_batch_storage.return_value = mock_batch_storage
|
||||
|
||||
mock_memory_tracer = MagicMock()
|
||||
mock_memory_tracer_class.return_value = mock_memory_tracer
|
||||
|
||||
# Create checkpoint mocks - initial checkpoint has_more=True, final has_more=False
|
||||
mock_initial_checkpoint = MagicMock(has_more=True)
|
||||
mock_final_checkpoint = MagicMock(has_more=False)
|
||||
|
||||
# get_latest_valid_checkpoint returns (checkpoint, resuming_from_checkpoint)
|
||||
mock_get_latest_valid_checkpoint.return_value = (mock_initial_checkpoint, False)
|
||||
|
||||
# Create a mock connector runner that yields one document batch
|
||||
mock_connector = MagicMock()
|
||||
mock_connector_runner = MagicMock()
|
||||
mock_connector_runner.connector = mock_connector
|
||||
# The connector runner yields (document_batch, failure, next_checkpoint)
|
||||
# We provide one batch of documents to trigger a send_task call
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.to_short_descriptor.return_value = "test_doc"
|
||||
mock_doc.sections = []
|
||||
mock_connector_runner.run.return_value = iter(
|
||||
[([mock_doc], None, mock_final_checkpoint)]
|
||||
)
|
||||
mock_get_connector_runner.return_value = mock_connector_runner
|
||||
|
||||
mock_get_recent_completed_attempts.return_value = iter([])
|
||||
mock_get_last_successful_attempt_poll_range_end.return_value = 0
|
||||
|
||||
# Mock celery app to capture task submission
|
||||
mock_celery_app = MagicMock()
|
||||
mock_celery_app.send_task.return_value = MagicMock()
|
||||
|
||||
# Call the function
|
||||
connector_document_extraction(
|
||||
app=mock_celery_app,
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
# Verify send_task was called with the expected priority for docprocessing
|
||||
assert mock_celery_app.send_task.called, "send_task should have been called"
|
||||
call_kwargs = mock_celery_app.send_task.call_args
|
||||
actual_priority = call_kwargs.kwargs["priority"]
|
||||
assert actual_priority == expected_priority, (
|
||||
f"Expected priority {expected_priority} for has_successful_index={has_successful_index}, "
|
||||
f"but got {actual_priority}"
|
||||
)
|
||||
@@ -31,7 +31,6 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import CustomToolCallSummary
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
@@ -388,9 +387,7 @@ class TestMCPPassThroughOAuth:
|
||||
):
|
||||
# Run the tool
|
||||
response = mcp_tool.run(
|
||||
placement=Placement(turn_index=0, tab_index=0),
|
||||
override_kwargs=None,
|
||||
input="test",
|
||||
turn_index=0, tab_index=0, override_kwargs=None, input="test"
|
||||
)
|
||||
print(response.rich_response)
|
||||
assert isinstance(response.rich_response, CustomToolCallSummary)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import generated.onyx_openapi_client.onyx_openapi_client as onyx_api # type: ignore[import-untyped,unused-ignore]
|
||||
import generated.onyx_openapi_client.onyx_openapi_client as onyx_api # type: ignore[import]
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
|
||||
api_config = onyx_api.Configuration(host=API_SERVER_URL)
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
"""
|
||||
Utilities for testing document access control lists (ACLs) and permissions.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.access.access import _get_access_for_documents
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_user_acl(user: User, db_session: Session) -> set[str]:
|
||||
"""
|
||||
Get the ACL entries for a user, including their external groups, email, and public doc pattern.
|
||||
|
||||
Args:
|
||||
user: The user object
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Set of ACL entries for the user
|
||||
"""
|
||||
db_external_groups = (
|
||||
fetch_external_groups_for_user(db_session, user.id) if user else []
|
||||
)
|
||||
prefixed_external_groups = [
|
||||
prefix_external_group(db_external_group.external_user_group_id)
|
||||
for db_external_group in db_external_groups
|
||||
]
|
||||
|
||||
user_acl = set(prefixed_external_groups)
|
||||
user_acl.update({prefix_user_email(user.email), PUBLIC_DOC_PAT})
|
||||
return user_acl
|
||||
|
||||
|
||||
def get_user_document_access_via_acl(
|
||||
test_user: DATestUser, document_ids: List[str], db_session: Session
|
||||
) -> List[str]:
|
||||
"""
|
||||
Determine which documents a user can access by comparing user ACL with document ACLs.
|
||||
|
||||
This is a more reliable method than search-based verification as it directly checks
|
||||
permission logic without depending on search relevance or ranking.
|
||||
|
||||
Args:
|
||||
test_user: The test user to check access for
|
||||
document_ids: List of document IDs to check
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
List of document IDs that the user can access
|
||||
"""
|
||||
# Get the actual User object from the database
|
||||
user = fetch_user_by_id(db_session, UUID(test_user.id))
|
||||
if not user:
|
||||
logger.error(f"Could not find user with ID {test_user.id}")
|
||||
return []
|
||||
|
||||
user_acl = get_user_acl(user, db_session)
|
||||
logger.info(f"User {user.email} ACL entries: {user_acl}")
|
||||
|
||||
# Get document access information
|
||||
doc_access_map = _get_access_for_documents(document_ids, db_session)
|
||||
logger.info(f"Found access info for {len(doc_access_map)} documents")
|
||||
|
||||
accessible_docs = []
|
||||
for doc_id, doc_access in doc_access_map.items():
|
||||
doc_acl = doc_access.to_acl()
|
||||
logger.info(f"Document {doc_id} ACL: {doc_acl}")
|
||||
|
||||
# Check if user has any matching ACL entry
|
||||
if user_acl.intersection(doc_acl):
|
||||
accessible_docs.append(doc_id)
|
||||
logger.info(f"User {user.email} has access to document {doc_id}")
|
||||
else:
|
||||
logger.info(f"User {user.email} does NOT have access to document {doc_id}")
|
||||
|
||||
return accessible_docs
|
||||
|
||||
|
||||
def get_all_connector_documents(
|
||||
cc_pair: DATestCCPair, db_session: Session
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all document IDs for a given connector/credential pair.
|
||||
|
||||
Args:
|
||||
cc_pair: The connector-credential pair
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
List of document IDs
|
||||
"""
|
||||
stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
document_ids = [row[0] for row in result.fetchall()]
|
||||
logger.info(
|
||||
f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
return document_ids
|
||||
|
||||
|
||||
def get_documents_by_permission_type(
|
||||
document_ids: List[str], db_session: Session
|
||||
) -> List[str]:
|
||||
"""
|
||||
Categorize documents by their permission types and return public documents.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to check
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
List of document IDs that are public
|
||||
"""
|
||||
doc_access_map = _get_access_for_documents(document_ids, db_session)
|
||||
|
||||
public_docs = []
|
||||
|
||||
for doc_id, doc_access in doc_access_map.items():
|
||||
if doc_access.is_public:
|
||||
public_docs.append(doc_id)
|
||||
|
||||
return public_docs
|
||||
@@ -5,7 +5,7 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
import generated.onyx_openapi_client.onyx_openapi_client as api # type: ignore[import-untyped,unused-ignore]
|
||||
import generated.onyx_openapi_client.onyx_openapi_client as api # type: ignore[import]
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
|
||||
@@ -49,14 +49,6 @@ class StreamPacketObj(TypedDict, total=False):
|
||||
documents: list[dict[str, Any]]
|
||||
|
||||
|
||||
class PlacementData(TypedDict, total=False):
|
||||
"""Structure for packet placement information."""
|
||||
|
||||
turn_index: int
|
||||
tab_index: int
|
||||
sub_turn_index: int | None
|
||||
|
||||
|
||||
class StreamPacketData(TypedDict, total=False):
|
||||
"""Structure for streaming response packets."""
|
||||
|
||||
@@ -64,7 +56,7 @@ class StreamPacketData(TypedDict, total=False):
|
||||
error: str
|
||||
stack_trace: str
|
||||
obj: StreamPacketObj
|
||||
placement: PlacementData
|
||||
turn_index: int
|
||||
|
||||
|
||||
class ChatSessionManager:
|
||||
@@ -200,7 +192,7 @@ class ChatSessionManager:
|
||||
(
|
||||
data.get("ind")
|
||||
if data.get("ind") is not None
|
||||
else data.get("placement", {}).get("turn_index")
|
||||
else data.get("turn_index")
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
GitHubTestEnvSetupTuple = tuple[
|
||||
DATestUser, # admin_user
|
||||
DATestUser, # test_user_1
|
||||
DATestUser, # test_user_2
|
||||
DATestCredential, # github_credential
|
||||
DATestConnector, # github_connector
|
||||
DATestCCPair, # github_cc_pair
|
||||
]
|
||||
|
||||
|
||||
def _get_github_test_tokens() -> list[str]:
|
||||
"""
|
||||
Returns a list of GitHub tokens to run the GitHub connector suite against.
|
||||
|
||||
Minimal setup:
|
||||
- Set GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN (token1)
|
||||
Optional:
|
||||
- Set GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC (token2 / classic)
|
||||
|
||||
If the classic token is provided, the GitHub suite will run twice (once per token).
|
||||
"""
|
||||
token_1 = os.environ.get("GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN")
|
||||
# Prefer the new "classic" name, but keep backward compatibility.
|
||||
token_2 = os.environ.get("GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC")
|
||||
|
||||
tokens: list[str] = []
|
||||
if token_1:
|
||||
tokens.append(token_1)
|
||||
if token_2:
|
||||
tokens.append(token_2)
|
||||
return tokens
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_get_github_test_tokens())
|
||||
def github_access_token(request: pytest.FixtureRequest) -> str:
|
||||
tokens = _get_github_test_tokens()
|
||||
if not tokens:
|
||||
pytest.skip(
|
||||
"Skipping GitHub tests due to missing env vars "
|
||||
"GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN and "
|
||||
"GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC"
|
||||
)
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def github_test_env_setup(
|
||||
github_access_token: str,
|
||||
) -> Generator[GitHubTestEnvSetupTuple]:
|
||||
"""
|
||||
Create a complete GitHub test environment with:
|
||||
- 3 users with email IDs from environment variables
|
||||
- GitHub credentials using ACCESS_TOKEN_GITHUB from environment
|
||||
- GitHub connector configured for testing
|
||||
- Connector-Credential pair linking them together
|
||||
|
||||
Returns:
|
||||
Tuple containing: (admin_user, test_user_1, test_user_2, github_credential, github_connector, github_cc_pair)
|
||||
"""
|
||||
# Reset all resources before setting up the test environment
|
||||
reset_all()
|
||||
|
||||
# Get user emails from environment (with fallbacks)
|
||||
admin_email = os.environ.get("GITHUB_ADMIN_EMAIL")
|
||||
test_user_1_email = os.environ.get("GITHUB_TEST_USER_1_EMAIL")
|
||||
test_user_2_email = os.environ.get("GITHUB_TEST_USER_2_EMAIL")
|
||||
|
||||
if not admin_email or not test_user_1_email or not test_user_2_email:
|
||||
pytest.skip(
|
||||
"Skipping GitHub test environment setup due to missing environment variables"
|
||||
)
|
||||
|
||||
# Create users
|
||||
admin_user: DATestUser = UserManager.create(email=admin_email)
|
||||
test_user_1: DATestUser = UserManager.create(email=test_user_1_email)
|
||||
test_user_2: DATestUser = UserManager.create(email=test_user_2_email)
|
||||
|
||||
# Create LLM provider - required for document search to work
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
# Create GitHub credentials
|
||||
github_credentials = {
|
||||
"github_access_token": github_access_token,
|
||||
}
|
||||
|
||||
github_credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.GITHUB,
|
||||
credential_json=github_credentials,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create GitHub connector
|
||||
github_connector: DATestConnector = ConnectorManager.create(
|
||||
name="GitHub Test Connector",
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.GITHUB,
|
||||
connector_specific_config={
|
||||
"repo_owner": "permission-sync-test",
|
||||
"include_prs": True,
|
||||
"repositories": "perm-sync-test-minimal",
|
||||
"include_issues": True,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create CC pair linking connector and credential
|
||||
github_cc_pair: DATestCCPair = CCPairManager.create(
|
||||
credential_id=github_credential.id,
|
||||
connector_id=github_connector.id,
|
||||
name="GitHub Test CC Pair",
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for initial indexing to complete
|
||||
# GitHub API operations can be slow due to rate limiting and network latency
|
||||
# Use a longer timeout for initial indexing to avoid flaky test failures
|
||||
before = datetime.now(tz=timezone.utc)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=github_cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
timeout=900,
|
||||
)
|
||||
|
||||
yield admin_user, test_user_1, test_user_2, github_credential, github_connector, github_cc_pair
|
||||
@@ -1,341 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
from github import Github
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.integration.common_utils.document_acl import (
|
||||
get_all_connector_documents,
|
||||
)
|
||||
from tests.integration.common_utils.document_acl import (
|
||||
get_user_document_access_via_acl,
|
||||
)
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.connector_job_tests.github.conftest import (
|
||||
GitHubTestEnvSetupTuple,
|
||||
)
|
||||
from tests.integration.connector_job_tests.github.utils import GitHubManager
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission tests are enterprise only",
|
||||
)
|
||||
def test_github_private_repo_permission_sync(
|
||||
github_test_env_setup: GitHubTestEnvSetupTuple,
|
||||
) -> None:
|
||||
|
||||
(
|
||||
admin_user,
|
||||
test_user_1,
|
||||
test_user_2,
|
||||
github_credential,
|
||||
github_connector,
|
||||
github_cc_pair,
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
|
||||
# Get repository configuration from connector
|
||||
repo_owner = github_connector.connector_specific_config["repo_owner"]
|
||||
repo_name = github_connector.connector_specific_config["repositories"]
|
||||
|
||||
success = github_manager.change_repository_visibility(
|
||||
repo_owner=repo_owner, repo_name=repo_name, visibility="private"
|
||||
)
|
||||
|
||||
if not success:
|
||||
pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to private")
|
||||
|
||||
# Add test-team to repository at the start
|
||||
logger.info(f"Adding test-team to repository {repo_owner}/{repo_name}")
|
||||
team_added = github_manager.add_team_to_repository(
|
||||
repo_owner=repo_owner,
|
||||
repo_name=repo_name,
|
||||
team_slug="test-team",
|
||||
permission="pull",
|
||||
)
|
||||
|
||||
if not team_added:
|
||||
logger.warning(
|
||||
f"Failed to add test-team to repository {repo_owner}/{repo_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
after = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Use a longer timeout for GitHub permission sync operations
|
||||
# GitHub API operations can be slow, especially with rate limiting
|
||||
# This accounts for document sync, group sync, and vespa sync operations
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
after=after,
|
||||
should_wait_for_group_sync=True,
|
||||
timeout=900,
|
||||
)
|
||||
|
||||
# ACL-based verification
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all documents for this connector
|
||||
all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
|
||||
|
||||
# Test access for both users using ACL verification
|
||||
accessible_docs_user1 = get_user_document_access_via_acl(
|
||||
test_user=test_user_1,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
accessible_docs_user2 = get_user_document_access_via_acl(
|
||||
test_user=test_user_2,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"test_user_1 has access to {len(accessible_docs_user1)} documents"
|
||||
)
|
||||
logger.info(
|
||||
f"test_user_2 has access to {len(accessible_docs_user2)} documents"
|
||||
)
|
||||
|
||||
# test_user_1 (part of test-team) should have access
|
||||
# test_user_2 (not part of test-team) should NOT have access
|
||||
assert len(accessible_docs_user1) > 0, (
|
||||
f"test_user_1 should have access to private repository documents. "
|
||||
f"Found {len(accessible_docs_user1)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
assert len(accessible_docs_user2) == 0, (
|
||||
f"test_user_2 should NOT have access to private repository documents. "
|
||||
f"Found {len(accessible_docs_user2)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Remove test-team from repository at the end
|
||||
logger.info(f"Removing test-team from repository {repo_owner}/{repo_name}")
|
||||
team_removed = github_manager.remove_team_from_repository(
|
||||
repo_owner=repo_owner, repo_name=repo_name, team_slug="test-team"
|
||||
)
|
||||
|
||||
if not team_removed:
|
||||
logger.warning(
|
||||
f"Failed to remove test-team from repository {repo_owner}/{repo_name}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission tests are enterprise only",
|
||||
)
|
||||
def test_github_public_repo_permission_sync(
|
||||
github_test_env_setup: GitHubTestEnvSetupTuple,
|
||||
) -> None:
|
||||
"""
|
||||
Test that when a repository is changed to public, both users can access the documents.
|
||||
"""
|
||||
(
|
||||
admin_user,
|
||||
test_user_1,
|
||||
test_user_2,
|
||||
github_credential,
|
||||
github_connector,
|
||||
github_cc_pair,
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
|
||||
# Get repository configuration from connector
|
||||
repo_owner = github_connector.connector_specific_config["repo_owner"]
|
||||
repo_name = github_connector.connector_specific_config["repositories"]
|
||||
|
||||
# Change repository to public
|
||||
logger.info(f"Changing repository {repo_owner}/{repo_name} to public")
|
||||
success = github_manager.change_repository_visibility(
|
||||
repo_owner=repo_owner, repo_name=repo_name, visibility="public"
|
||||
)
|
||||
|
||||
if not success:
|
||||
pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to public")
|
||||
|
||||
# Verify repository is now public
|
||||
current_visibility = github_manager.get_repository_visibility(
|
||||
repo_owner=repo_owner, repo_name=repo_name
|
||||
)
|
||||
logger.info(f"Repository {repo_owner}/{repo_name} visibility: {current_visibility}")
|
||||
assert (
|
||||
current_visibility == "public"
|
||||
), f"Repository should be public, but is {current_visibility}"
|
||||
|
||||
# Trigger sync to update permissions
|
||||
after = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for sync to complete with group sync
|
||||
# Public repositories should be accessible to all users
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
after=after,
|
||||
should_wait_for_group_sync=True,
|
||||
timeout=900,
|
||||
)
|
||||
|
||||
# ACL-based verification
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all documents for this connector
|
||||
all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
|
||||
|
||||
# Test access for both users using ACL verification
|
||||
accessible_docs_user1 = get_user_document_access_via_acl(
|
||||
test_user=test_user_1,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
accessible_docs_user2 = get_user_document_access_via_acl(
|
||||
test_user=test_user_2,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(f"test_user_1 has access to {len(accessible_docs_user1)} documents")
|
||||
logger.info(f"test_user_2 has access to {len(accessible_docs_user2)} documents")
|
||||
|
||||
# Both users should have access to the public repository documents
|
||||
assert len(accessible_docs_user1) > 0, (
|
||||
f"test_user_1 should have access to public repository documents. "
|
||||
f"Found {len(accessible_docs_user1)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
assert len(accessible_docs_user2) > 0, (
|
||||
f"test_user_2 should have access to public repository documents. "
|
||||
f"Found {len(accessible_docs_user2)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
|
||||
# Verify that both users get the same results (since repo is public)
|
||||
assert len(accessible_docs_user1) == len(accessible_docs_user2), (
|
||||
f"Both users should see the same documents from public repository. "
|
||||
f"User1: {len(accessible_docs_user1)}, User2: {len(accessible_docs_user2)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission tests are enterprise only",
|
||||
)
|
||||
def test_github_internal_repo_permission_sync(
|
||||
github_test_env_setup: GitHubTestEnvSetupTuple,
|
||||
) -> None:
|
||||
"""
|
||||
Test that when a repository is changed to internal, test_user_1 has access but test_user_2 doesn't.
|
||||
Internal repositories are accessible only to organization members.
|
||||
"""
|
||||
(
|
||||
admin_user,
|
||||
test_user_1,
|
||||
test_user_2,
|
||||
github_credential,
|
||||
github_connector,
|
||||
github_cc_pair,
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
|
||||
# Get repository configuration from connector
|
||||
repo_owner = github_connector.connector_specific_config["repo_owner"]
|
||||
repo_name = github_connector.connector_specific_config["repositories"]
|
||||
|
||||
# Change repository to internal
|
||||
logger.info(f"Changing repository {repo_owner}/{repo_name} to internal")
|
||||
success = github_manager.change_repository_visibility(
|
||||
repo_owner=repo_owner, repo_name=repo_name, visibility="internal"
|
||||
)
|
||||
|
||||
if not success:
|
||||
pytest.fail(f"Failed to change repository {repo_owner}/{repo_name} to internal")
|
||||
|
||||
# Verify repository is now internal
|
||||
current_visibility = github_manager.get_repository_visibility(
|
||||
repo_owner=repo_owner, repo_name=repo_name
|
||||
)
|
||||
logger.info(f"Repository {repo_owner}/{repo_name} visibility: {current_visibility}")
|
||||
assert (
|
||||
current_visibility == "internal"
|
||||
), f"Repository should be internal, but is {current_visibility}"
|
||||
|
||||
# Trigger sync to update permissions
|
||||
after = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for sync to complete with group sync
|
||||
# Internal repositories should be accessible only to organization members
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=github_cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
after=after,
|
||||
should_wait_for_group_sync=True,
|
||||
timeout=900,
|
||||
)
|
||||
|
||||
# ACL-based verification
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all documents for this connector
|
||||
all_document_ids = get_all_connector_documents(github_cc_pair, db_session)
|
||||
|
||||
# Test access for both users using ACL verification
|
||||
accessible_docs_user1 = get_user_document_access_via_acl(
|
||||
test_user=test_user_1,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
accessible_docs_user2 = get_user_document_access_via_acl(
|
||||
test_user=test_user_2,
|
||||
document_ids=all_document_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(f"test_user_1 has access to {len(accessible_docs_user1)} documents")
|
||||
logger.info(f"test_user_2 has access to {len(accessible_docs_user2)} documents")
|
||||
|
||||
# For internal repositories:
|
||||
# - test_user_1 should have access (assuming they're part of the organization)
|
||||
# - test_user_2 should NOT have access (assuming they're not part of the organization)
|
||||
assert len(accessible_docs_user1) > 0, (
|
||||
f"test_user_1 should have access to internal repository documents (organization member). "
|
||||
f"Found {len(accessible_docs_user1)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
assert len(accessible_docs_user2) == 0, (
|
||||
f"test_user_2 should NOT have access to internal repository documents (not organization member). "
|
||||
f"Found {len(accessible_docs_user2)} accessible docs out of "
|
||||
f"{len(all_document_ids)} total"
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user