Compare commits

..

36 Commits

Author SHA1 Message Date
Dane Urban
6985661dcd . 2026-03-26 10:33:56 +11:00
Dane Urban
3e2a10ce9d . 2026-03-26 10:26:58 +11:00
Dane Urban
389eb6c281 Update 2026-03-26 09:59:26 +11:00
Dane Urban
ff88d1886b Update 2026-03-26 09:57:24 +11:00
Dane Urban
18dac2ba71 . 2026-03-24 21:40:43 +11:00
Dane Urban
96cd5bb751 . 2026-03-24 21:31:21 +11:00
Dane Urban
30a7c40c55 . 2026-03-24 21:24:44 +11:00
Dane Urban
641fb61c45 . 2026-03-24 21:23:36 +11:00
Dane Urban
6f8d9cfdd7 . 2026-03-24 21:23:36 +11:00
Dane Urban
2784e42cfe . 2026-03-24 21:23:36 +11:00
Dane Urban
4f5fc65428 . 2026-03-24 21:23:36 +11:00
Dane Urban
8fcdd3a3fb . 2026-03-24 21:23:36 +11:00
Dane Urban
3b7c53aeb1 Adapter refactor 2026-03-24 21:23:36 +11:00
Dane Urban
ea58e82aed Support streaming via document adapter 2026-03-24 21:23:36 +11:00
Dane Urban
bd35585785 Add extra tests 2026-03-24 21:23:34 +11:00
Dane Urban
cf9bd7e511 . 2026-03-24 21:21:01 +11:00
Dane Urban
b5dd17a371 . 2026-03-24 21:19:38 +11:00
Dane Urban
d62d0c1864 . 2026-03-24 21:14:52 +11:00
Dane Urban
2c92742c62 . 2026-03-24 21:06:48 +11:00
Dane Urban
1e1402e4f1 . 2026-03-24 21:06:48 +11:00
Dane Urban
440818a082 Max chunks 2026-03-24 21:06:48 +11:00
Dane Urban
bd9f40d1c1 . 2026-03-24 21:06:46 +11:00
Dane Urban
c85e090c13 . 2026-03-24 21:06:23 +11:00
Dane Urban
d72df59063 . 2026-03-24 21:04:07 +11:00
Dane Urban
867442bc54 . 2026-03-24 21:04:07 +11:00
Dane Urban
f752761e46 Fix comment 2026-03-24 21:04:07 +11:00
Dane Urban
a760d1cf33 Add extra tests 2026-03-24 21:04:07 +11:00
Dane Urban
acffd55ce4 Add comments 2026-03-24 21:04:07 +11:00
Dane Urban
3a4be4a7d9 Remove restriction comment 2026-03-24 21:04:07 +11:00
Dane Urban
7c0e7eddbd mypy fixes 2026-03-24 21:04:07 +11:00
Dane Urban
2e5763c9ab . 2026-03-24 21:04:06 +11:00
Dane Urban
5c45345521 Vespa change 2026-03-24 21:04:06 +11:00
Dane Urban
0665f31a7d Open-search iterable refactor 2026-03-24 21:04:06 +11:00
Dane Urban
17442ed2d0 . 2026-03-24 21:04:05 +11:00
Dane Urban
5b0c2f3c18 . 2026-03-24 21:03:53 +11:00
Dane Urban
cff564eb6a Add tests 2026-03-24 20:59:18 +11:00
388 changed files with 5795 additions and 18678 deletions

View File

@@ -6,4 +6,3 @@
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339)

View File

@@ -615,7 +615,6 @@ jobs:
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
@@ -1264,6 +1263,8 @@ jobs:
latest=false
tags: |
type=raw,value=craft-latest
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
# to keep tagging strategy consistent across all backend images
- name: Create and push manifest
env:
@@ -1487,7 +1488,6 @@ jobs:
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

View File

@@ -7,15 +7,6 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-external-dependency-unit-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
- "deployment/docker_compose/docker-compose.yml"
- "deployment/docker_compose/docker-compose.dev.yml"
push:
tags:
- "v*.*.*"

View File

@@ -7,13 +7,6 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-python-connector-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
push:
tags:
- "v*.*.*"

View File

@@ -1,64 +0,0 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}

View File

@@ -1,57 +0,0 @@
[
{
"scope": [],
"path": "contributing_guides/best_practices.md",
"description": "Best practices for contributing to the codebase"
},
{
"scope": ["web/**"],
"path": "web/AGENTS.md",
"description": "Frontend coding standards for the web directory"
},
{
"scope": ["web/**"],
"path": "web/tests/README.md",
"description": "Frontend testing guide and conventions"
},
{
"scope": ["web/**"],
"path": "web/CLAUDE.md",
"description": "Single source of truth for frontend coding standards"
},
{
"scope": ["web/**"],
"path": "web/lib/opal/README.md",
"description": "Opal component library usage guide"
},
{
"scope": ["backend/**"],
"path": "backend/tests/README.md",
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
},
{
"scope": ["backend/onyx/connectors/**"],
"path": "backend/onyx/connectors/README.md",
"description": "Connector development guide covering design, interfaces, and required changes"
},
{
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]

View File

@@ -1,29 +0,0 @@
# Greptile Review Rules
## Type Annotations
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
## Best Practices
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
## TODOs
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
## Debugging Code
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
## Hardcoded Booleans
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
## Multi-tenant vs Single-tenant
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

12
.vscode/launch.json vendored
View File

@@ -117,8 +117,7 @@
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console",
"justMyCode": false
"consoleTitle": "API Server Console"
},
{
"name": "Slack Bot",
@@ -269,8 +268,7 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console",
"justMyCode": false
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery kg_processing",
@@ -357,8 +355,7 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console",
"justMyCode": false
"consoleTitle": "Celery user_file_processing Console"
},
{
"name": "Celery docfetching",
@@ -416,8 +413,7 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
"consoleTitle": "Celery docprocessing Console"
},
{
"name": "Celery beat",

View File

@@ -1,35 +0,0 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -1,109 +0,0 @@
"""group_permissions_phase1
Revision ID: 25a5501dc766
Revises: b728689f45b1
Create Date: 2026-03-23 11:41:25.557442
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
# revision identifiers, used by Alembic.
revision = "25a5501dc766"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1. Add account_type column to user table (nullable for now).
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
op.add_column(
"user",
sa.Column(
"account_type",
sa.Enum(AccountType, native_enum=False),
nullable=True,
),
)
# 2. Add is_default column to user_group table
op.add_column(
"user_group",
sa.Column(
"is_default",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# 3. Create permission_grant table
op.create_table(
"permission_grant",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("group_id", sa.Integer(), nullable=False),
sa.Column(
"permission",
sa.Enum(Permission, native_enum=False),
nullable=False,
),
sa.Column(
"grant_source",
sa.Enum(GrantSource, native_enum=False),
nullable=False,
),
sa.Column(
"granted_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"granted_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"is_deleted",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["group_id"],
["user_group.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["granted_by"],
["user.id"],
ondelete="SET NULL",
),
sa.UniqueConstraint(
"group_id", "permission", name="uq_permission_grant_group_permission"
),
)
# 4. Index on user__user_group(user_id) — existing composite PK
# has user_group_id as leading column; user-filtered queries need this
op.create_index(
"ix_user__user_group_user_id",
"user__user_group",
["user_id"],
)
def downgrade() -> None:
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
op.drop_table("permission_grant")
op.drop_column("user_group", "is_default")
op.drop_column("user", "account_type")

View File

@@ -1,36 +0,0 @@
"""add preferred_response_id and model_display_name to chat_message
Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "25a5501dc766"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"preferred_response_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="SET NULL"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("model_display_name", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "model_display_name")
op.drop_column("chat_message", "preferred_response_id")

View File

@@ -25,13 +25,10 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Maximum tenants to provision in a single task run.
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
@shared_task(
@@ -88,26 +85,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
f"To provision: {tenants_to_provision}"
)
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
for i in range(batch_size):
task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
try:
if pre_provision_tenant():
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# just provision one tenant each time we run this ... increase if needed.
if tenants_to_provision > 0:
pre_provision_tenant()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -121,13 +101,11 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def pre_provision_tenant() -> bool:
def pre_provision_tenant() -> None:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
Returns True if a tenant was successfully provisioned, False otherwise.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
@@ -140,10 +118,10 @@ def pre_provision_tenant() -> bool:
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.warning(
"Skipping pre_provision_tenant — could not acquire provision lock"
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
)
return False
return
tenant_id: str | None = None
try:
@@ -183,7 +161,6 @@ def pre_provision_tenant() -> bool:
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
return True
except Exception:
db_session.rollback()
task_logger.error(
@@ -207,7 +184,6 @@ def pre_provision_tenant() -> bool:
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
return False
finally:
try:
lock_provision.release()

View File

@@ -115,14 +115,8 @@ def fetch_user_group_token_rate_limits_for_user(
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
)
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:

View File

@@ -800,33 +800,6 @@ def update_user_group(
return db_user_group
def rename_user_group(
db_session: Session,
user_group_id: int,
new_name: str,
) -> UserGroup:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
_check_user_group_is_modifiable(db_user_group)
db_user_group.name = new_name
db_user_group.time_last_modified_by_user = func.now()
# CC pair documents in Vespa contain the group name, so we need to
# trigger a sync to update them with the new name.
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
if not DISABLE_VECTOR_DB:
db_user_group.is_up_to_date = False
db_session.commit()
return db_user_group
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)

View File

@@ -250,24 +250,20 @@ def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
raise e
def _is_public_item(
drive_item: DriveItem,
treat_sharing_link_as_public: bool = False,
) -> bool:
if not treat_sharing_link_as_public:
return False
def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and permission.link.scope in (
"anonymous",
"organization",
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
):
return True
return False
is_public = True
break
return is_public
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
@@ -508,7 +504,6 @@ def get_external_access_from_sharepoint(
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
treat_sharing_link_as_public: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
@@ -568,7 +563,8 @@ def get_external_access_from_sharepoint(
)
if drive_item and drive_name:
is_public = _is_public_item(drive_item, treat_sharing_link_as_public)
# Here we check if the item have have any public links, if so we return early
is_public = _is_public_item(drive_item)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(

View File

@@ -44,21 +44,19 @@ def _run_single_search(
user: User,
db_session: Session,
num_hits: int | None = None,
hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
"""Execute a single search query and return chunks."""
chunk_search_request = ChunkSearchRequest(
query=query,
user_selected_filters=filters,
limit=num_hits,
hybrid_alpha=hybrid_alpha,
)
return search_pipeline(
chunk_search_request=chunk_search_request,
document_index=document_index,
user=user,
persona_search_info=None,
persona=None, # No persona for direct search
db_session=db_session,
)
@@ -76,7 +74,7 @@ def stream_search_query(
Core search function that yields streaming packets.
Used by both streaming and non-streaming endpoints.
"""
# Get document index.
# Get document index
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
@@ -121,7 +119,6 @@ def stream_search_query(
user=user,
db_session=db_session,
num_hits=request.num_hits,
hybrid_alpha=request.hybrid_alpha,
)
else:
# Multiple queries - run in parallel and merge with RRF
@@ -136,7 +133,6 @@ def stream_search_query(
user,
db_session,
request.num_hits,
request.hybrid_alpha,
),
)
for query in all_executed_queries

View File

@@ -27,17 +27,15 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application, any
# changes to it should be reviewed and approved by an experienced team member.
# It is very important to 1. avoid bloat and 2. that this remains backwards
# compatible across versions.
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30
hybrid_alpha: float | None = None
include_content: bool = False
stream: bool = False

View File

@@ -20,7 +20,6 @@ from ee.onyx.server.query_and_chat.models import SearchQueryResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from onyx.auth.users import current_user
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
@@ -68,10 +67,8 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
# NOTE: This endpoint is used for the core flow of the Onyx application, any
# changes to it should be reviewed and approved by an experienced team member.
# It is very important to 1. avoid bloat and 2. that this remains backwards
# compatible across versions.
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
"/send-search-message",
response_model=None,
@@ -83,19 +80,13 @@ def handle_send_search_message(
db_session: Session = Depends(get_session),
) -> StreamingResponse | SearchFullResponse:
"""
Executes a search query with optional streaming.
Execute a search query with optional streaming.
If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
is True, executes pure keyword search.
Returns:
StreamingResponse with SSE if stream=True, otherwise SearchFullResponse.
When stream=True: Returns StreamingResponse with SSE
When stream=False: Returns SearchFullResponse
"""
logger.debug(f"Received search query: {request.search_query}")
if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH:
request.hybrid_alpha = 0.0
# Non-streaming path
if not request.stream:
try:

View File

@@ -4,7 +4,6 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.persona import update_persona_access
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
@@ -12,16 +11,13 @@ from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import rename_user_group
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupRename
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -31,9 +27,6 @@ from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.db.persona import get_persona_by_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -94,32 +87,6 @@ def create_user_group(
return UserGroup.from_model(db_user_group)
@router.patch("/admin/user-group/rename")
def rename_user_group_endpoint(
rename_request: UserGroupRename,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
rename_user_group(
db_session=db_session,
user_group_id=rename_request.id,
new_name=rename_request.name,
)
)
except IntegrityError:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"User group with name '{rename_request.name}' already exists.",
)
except ValueError as e:
msg = str(e)
if "not found" in msg.lower():
raise OnyxError(OnyxErrorCode.NOT_FOUND, msg)
raise OnyxError(OnyxErrorCode.CONFLICT, msg)
@router.patch("/admin/user-group/{user_group_id}")
def patch_user_group(
user_group_id: int,
@@ -194,38 +161,3 @@ def delete_user_group(
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)
@router.patch("/admin/user-group/{user_group_id}/agents")
def update_group_agents(
user_group_id: int,
request: UpdateGroupAgentsRequest,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
for agent_id in request.added_agent_ids:
persona = get_persona_by_id(
persona_id=agent_id, user=user, db_session=db_session
)
current_group_ids = [g.id for g in persona.groups]
if user_group_id not in current_group_ids:
update_persona_access(
persona_id=agent_id,
creator_user_id=user.id,
db_session=db_session,
group_ids=current_group_ids + [user_group_id],
)
for agent_id in request.removed_agent_ids:
persona = get_persona_by_id(
persona_id=agent_id, user=user, db_session=db_session
)
current_group_ids = [g.id for g in persona.groups]
update_persona_access(
persona_id=agent_id,
creator_user_id=user.id,
db_session=db_session,
group_ids=[gid for gid in current_group_ids if gid != user_group_id],
)
db_session.commit()

View File

@@ -104,16 +104,6 @@ class AddUsersToUserGroupRequest(BaseModel):
user_ids: list[UUID]
class UserGroupRename(BaseModel):
id: int
name: str
class SetCuratorRequest(BaseModel):
user_id: UUID
is_curator: bool
class UpdateGroupAgentsRequest(BaseModel):
added_agent_ids: list[int]
removed_agent_ids: list[int]

View File

@@ -13,14 +13,6 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -42,8 +34,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -58,36 +48,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -116,7 +76,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docfetching")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -14,14 +14,6 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -43,8 +35,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -59,36 +49,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -122,7 +82,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docprocessing")
app_base.on_worker_ready(sender, **kwargs)
@@ -131,12 +90,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
# effectively a no-op in normal operation. It remains as a safety net in case
# the pool type is ever changed to prefork. Prometheus metrics are safe in
# thread-pool mode since all threads share the same process memory and can
# update the same Counter/Gauge/Histogram objects directly.
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()

View File

@@ -54,14 +54,8 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
app_base.on_celeryd_init(sender, conf, **kwargs)
# Set by on_worker_init so on_worker_ready knows whether to start the server.
_prometheus_collectors_ok: bool = False
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
global _prometheus_collectors_ok
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
@@ -71,8 +65,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
@@ -80,37 +72,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.on_secondary_worker_init(sender, **kwargs)
def _setup_prometheus_collectors(sender: Any) -> bool:
"""Register Prometheus collectors that need Redis/DB access.
Passes the Celery app so the queue depth collector can obtain a fresh
broker Redis client on each scrape (rather than holding a stale reference).
Returns True if registration succeeded, False otherwise.
"""
try:
from onyx.server.metrics.indexing_pipeline_setup import (
setup_indexing_pipeline_metrics,
)
setup_indexing_pipeline_metrics(sender.app)
logger.info("Prometheus indexing pipeline collectors registered")
return True
except Exception:
logger.exception("Failed to register Prometheus indexing pipeline collectors")
return False
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
if _prometheus_collectors_ok:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring")
else:
logger.warning(
"Skipping Prometheus metrics server — collector registration failed"
)
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -8,7 +8,6 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -36,13 +35,7 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerStreamPart = (
Packet
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamingError
| CreateChatSessionID
)
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -59,7 +59,6 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -69,19 +68,11 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -433,28 +424,6 @@ def determine_search_params(
)
def _resolve_query_processing_hook_result(
hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not (hook_result.query and hook_result.query.strip()):
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message
or "The hook extension for query processing did not return a valid query. No rejection reason was provided.",
)
return hook_result.query.strip()
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -505,24 +474,16 @@ def handle_stream_message_objects(
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
chat_session = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -614,28 +575,6 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
# Skip the hook for empty/whitespace-only messages — no meaningful query
# to process, and SendMessageRequest.message has no min_length guard.
if message_text.strip():
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
response_type=QueryProcessingResponse,
)
message_text = _resolve_query_processing_hook_result(
hook_result, message_text
)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -975,17 +914,6 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -332,10 +332,6 @@ OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
else None
)
ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
== "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a
@@ -791,6 +787,10 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -24,11 +24,11 @@ CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
LLM_SOCKET_READ_TIMEOUT = int(
os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60"
) # 60 seconds
# Weighting factor between vector and keyword Search; 1 for completely vector
# search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from
# the env outside of this range will be clipped to the respective end of the
# range. Defaults to 0.5.
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear

View File

@@ -1,192 +0,0 @@
from __future__ import annotations
import logging
import re
from typing import Any
from urllib.parse import urlparse
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
logger = logging.getLogger(__name__)
# Requests timeout in seconds.
_CANVAS_CALL_TIMEOUT: int = 30
_CANVAS_API_VERSION: str = "/api/v1"
# Matches the "next" URL in a Canvas Link header, e.g.:
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
# Captures the URL inside the angle brackets.
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]
return OnyxErrorCode.BAD_GATEWAY
class CanvasApiClient:
def __init__(
self,
bearer_token: str,
canvas_base_url: str,
) -> None:
parsed_base = urlparse(canvas_base_url)
if not parsed_base.hostname:
raise ValueError("canvas_base_url must include a valid host")
if parsed_base.scheme != "https":
raise ValueError("canvas_base_url must use https")
self._bearer_token = bearer_token
self.base_url = (
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
+ _CANVAS_API_VERSION
)
# Hostname is already validated above; reuse parsed_base instead
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
self._expected_host: str = parsed_base.hostname
def get(
self,
endpoint: str = "",
params: dict[str, Any] | None = None,
full_url: str | None = None,
) -> tuple[Any, str | None]:
"""Make a GET request to the Canvas API.
Returns a tuple of (json_body, next_url).
next_url is parsed from the Link header and is None if there are no more pages.
If full_url is provided, it is used directly (for following pagination links).
Security note: full_url must only be set to values returned by
``_parse_next_link``, which validates the host against the configured
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
"""
# full_url is used when following pagination (Canvas returns the
# next-page URL in the Link header). For the first request we build
# the URL from the endpoint name instead.
url = full_url if full_url else self._build_url(endpoint)
headers = self._build_headers()
response = rl_requests.get(
url,
headers=headers,
params=params if not full_url else None,
timeout=_CANVAS_CALL_TIMEOUT,
)
try:
response_json = response.json()
except ValueError as e:
if response.status_code < 300:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=f"Invalid JSON in Canvas response: {e}",
)
logger.warning(
"Failed to parse JSON from Canvas error response (status=%d): %s",
response.status_code,
e,
)
response_json = {}
if response.status_code >= 400:
# Try to extract the most specific error message from the
# Canvas response body. Canvas uses three different shapes
# depending on the endpoint and error type:
default_error: str = response.reason or f"HTTP {response.status_code}"
error = default_error
if isinstance(response_json, dict):
# Shape 1: {"error": {"message": "Not authorized"}}
error_field = response_json.get("error")
if isinstance(error_field, dict):
response_error = error_field.get("message", "")
if response_error:
error = response_error
# Shape 2: {"error": "Invalid access token"}
elif isinstance(error_field, str):
error = error_field
# Shape 3: {"errors": [{"message": "..."}]}
# Used for validation errors. Only use as fallback if
# we didn't already find a more specific message above.
if error == default_error:
errors_list = response_json.get("errors")
if isinstance(errors_list, list) and errors_list:
first_error = errors_list[0]
if isinstance(first_error, dict):
msg = first_error.get("message", "")
if msg:
error = msg
raise OnyxError(
_error_code_for_status(response.status_code),
detail=error,
status_code_override=response.status_code,
)
next_url = self._parse_next_link(response.headers.get("Link", ""))
return response_json, next_url
def _parse_next_link(self, link_header: str) -> str | None:
"""Extract the 'next' URL from a Canvas Link header.
Only returns URLs whose host matches the configured Canvas base URL
to prevent leaking the bearer token to arbitrary hosts.
"""
expected_host = self._expected_host
for match in _NEXT_LINK_PATTERN.finditer(link_header):
url = match.group(1)
parsed_url = urlparse(url)
if parsed_url.hostname != expected_host:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination returned an unexpected host "
f"({parsed_url.hostname}); expected {expected_host}"
),
)
if parsed_url.scheme != "https":
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination link must use https, "
f"got {parsed_url.scheme!r}"
),
)
return url
return None
def _build_headers(self) -> dict[str, str]:
"""Return the Authorization header with the bearer token."""
return {"Authorization": f"Bearer {self._bearer_token}"}
def _build_url(self, endpoint: str) -> str:
"""Build a full Canvas API URL from an endpoint path.
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
Only called on a first request, endpoint must be set for first request.
Verify endpoint exists in case of future changes where endpoint might be optional.
Leading slashes are stripped to avoid double-slash in the result.
self.base_url is already normalized with no trailing slash.
"""
final_url = self.base_url
clean_endpoint = endpoint.lstrip("/")
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url

View File

@@ -1,74 +0,0 @@
from typing import Literal
from typing import TypeAlias
from pydantic import BaseModel
from onyx.connectors.models import ConnectorCheckpoint
class CanvasCourse(BaseModel):
id: int
name: str
course_code: str
created_at: str
workflow_state: str
class CanvasPage(BaseModel):
page_id: int
url: str
title: str
body: str | None = None
created_at: str
updated_at: str
course_id: int
class CanvasAssignment(BaseModel):
id: int
name: str
description: str | None = None
html_url: str
course_id: int
created_at: str
updated_at: str
due_at: str | None = None
class CanvasAnnouncement(BaseModel):
id: int
title: str
message: str | None = None
html_url: str
posted_at: str | None = None
course_id: int
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
Fields:
course_ids: Materialized list of course IDs to process.
current_course_index: Index into course_ids for current course.
stage: Which item type we're processing for the current course.
next_url: Pagination cursor within the current stage. None means
start from the first page; a URL means resume from that page.
Invariant:
If current_course_index is incremented, stage must be reset to
"pages" and next_url must be reset to None.
"""
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = "pages"
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = "pages"
self.next_url = None

View File

@@ -123,7 +123,7 @@ class OnyxConfluence:
self.shared_base_kwargs: dict[str, str | int | bool] = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": False,
"backoff_and_retry": True,
"cloud": is_cloud,
}
if timeout:
@@ -456,7 +456,7 @@ class OnyxConfluence:
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)

View File

@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
@@ -408,17 +408,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) ->
raise e
if e.response.status_code >= 500:
if attempt >= max_retries - 1:
raise e
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
logger.warning(
f"Server error {e.response.status_code}. "
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
)
return math.ceil(time.monotonic() + delay)
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -10,7 +10,6 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -240,53 +239,29 @@ def enhanced_search_ids(
)
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
raise e
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in raw_issues
for issue in response["issues"]
]

View File

@@ -53,7 +53,7 @@ class NotionPage(BaseModel):
id: str
created_time: str
last_edited_time: str
in_trash: bool
archived: bool
properties: dict[str, Any]
url: str
@@ -63,13 +63,6 @@ class NotionPage(BaseModel):
)
class NotionDataSource(BaseModel):
"""Represents a Notion Data Source within a database."""
id: str
name: str = ""
class NotionBlock(BaseModel):
"""Represents a Notion Block object"""
@@ -114,7 +107,7 @@ class NotionConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.headers = {
"Content-Type": "application/json",
"Notion-Version": "2026-03-11",
"Notion-Version": "2022-06-28",
}
self.indexed_pages: set[str] = set()
self.root_page_id = root_page_id
@@ -134,9 +127,6 @@ class NotionConnector(LoadConnector, PollConnector):
# Maps child page IDs to their containing page ID (discovered in _read_blocks).
# Used to resolve block_id parent types to the actual containing page.
self._child_page_parent_map: dict[str, str] = {}
# Maps data_source_id -> database_id (populated in _read_pages_from_database).
# Used to resolve data_source_id parent types back to the database.
self._data_source_to_database_map: dict[str, str] = {}
@classmethod
@override
@@ -237,11 +227,7 @@ class NotionConnector(LoadConnector, PollConnector):
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page.
Note: As of API 2025-09-03, database objects no longer include
`properties` (schema moved to individual data sources).
"""
"""Attempt to fetch a database as a page."""
logger.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
res = rl_requests.get(
@@ -260,52 +246,18 @@ class NotionConnector(LoadConnector, PollConnector):
database_name[0].get("text", {}).get("content") if database_name else None
)
db_data.setdefault("properties", {})
return NotionPage(**db_data, database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
def _fetch_data_sources_for_database(
self, database_id: str
) -> list[NotionDataSource]:
"""Fetch the list of data sources for a database."""
logger.debug(f"Fetching data sources for database '{database_id}'")
res = rl_requests.get(
f"https://api.notion.com/v1/databases/{database_id}",
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
if res.status_code in (403, 404):
logger.error(
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared "
f"with the Onyx integration. Exact exception:\n{e}"
)
return []
logger.exception(f"Error fetching database - {res.json()}")
raise e
db_data = res.json()
data_sources = db_data.get("data_sources", [])
return [
NotionDataSource(id=ds["id"], name=ds.get("name", ""))
for ds in data_sources
if ds.get("id")
]
@retry(tries=3, delay=1, backoff=2)
def _fetch_data_source(
self, data_source_id: str, cursor: str | None = None
def _fetch_database(
self, database_id: str, cursor: str | None = None
) -> dict[str, Any]:
"""Query a data source via POST /v1/data_sources/{id}/query."""
logger.debug(f"Querying data source '{data_source_id}'")
url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
"""Fetch a database from it's ID via the Notion API."""
logger.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = None if not cursor else {"start_cursor": cursor}
res = rl_requests.post(
url,
block_url,
headers=self.headers,
json=body,
timeout=_NOTION_CALL_TIMEOUT,
@@ -313,14 +265,25 @@ class NotionConnector(LoadConnector, PollConnector):
try:
res.raise_for_status()
except Exception as e:
if res.status_code in (403, 404):
json_data = res.json()
code = json_data.get("code")
# Sep 3 2025 backend changed the error message for this case
# TODO: it is also now possible for there to be multiple data sources per database; at present we
# just don't handle that. We will need to upgrade the API to the current version + query the
# new data sources endpoint to handle that case correctly.
if code == "object_not_found" or (
code == "validation_error"
and "does not contain any data sources" in json_data.get("message", "")
):
# this happens when a database is not shared with the integration
# in this case, we should just ignore the database
logger.error(
f"Unable to access data source with ID '{data_source_id}'. "
f"This is likely due to it not being shared "
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared "
f"with the Onyx integration. Exact exception:\n{e}"
)
return {"results": [], "next_cursor": None}
logger.exception(f"Error querying data source - {res.json()}")
logger.exception(f"Error fetching database - {res.json()}")
raise e
return res.json()
@@ -385,9 +348,8 @@ class NotionConnector(LoadConnector, PollConnector):
# Fallback to workspace if we don't know the parent
return self.workspace_id
elif parent_type == "data_source_id":
ds_id = parent.get("data_source_id")
if ds_id:
return self._data_source_to_database_map.get(ds_id, self.workspace_id)
# Newer Notion API may use data_source_id for databases
return parent.get("database_id") or parent.get("data_source_id")
elif parent_type in ["page_id", "database_id"]:
return parent.get(parent_type)
@@ -535,32 +497,18 @@ class NotionConnector(LoadConnector, PollConnector):
if db_node:
hierarchy_nodes.append(db_node)
# Discover all data sources under this database, then query each one.
# Even legacy single-source databases have one entry in the array.
data_sources = self._fetch_data_sources_for_database(database_id)
if not data_sources:
logger.warning(
f"Database '{database_id}' returned zero data sources — "
f"no pages will be indexed from this database."
)
for ds in data_sources:
self._data_source_to_database_map[ds.id] = database_id
cursor = None
while True:
data = self._fetch_data_source(ds.id, cursor)
cursor = None
while True:
data = self._fetch_database(database_id, cursor)
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(
NotionBlock(id=obj_id, text=text, prefix="\n")
)
if not self.recursive_index_enabled:
continue
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
if self.recursive_index_enabled:
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
@@ -570,6 +518,7 @@ class NotionConnector(LoadConnector, PollConnector):
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
# Get nested database name from properties if available
nested_db_title = result.get("title", [])
nested_db_name = None
if nested_db_title and len(nested_db_title) > 0:
@@ -584,10 +533,10 @@ class NotionConnector(LoadConnector, PollConnector):
result_pages.extend(nested_output.child_page_ids)
hierarchy_nodes.extend(nested_output.hierarchy_nodes)
if data["next_cursor"] is None:
break
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
cursor = data["next_cursor"]
return BlockReadOutput(
blocks=result_blocks,
@@ -858,55 +807,36 @@ class NotionConnector(LoadConnector, PollConnector):
def _yield_database_hierarchy_nodes(
self,
) -> Generator[HierarchyNode | Document, None, None]:
"""Search for all data sources and yield hierarchy nodes for their parent databases.
"""Search for all databases and yield hierarchy nodes for each.
This must be called BEFORE page indexing so that database hierarchy nodes
exist when pages inside databases reference them as parents.
With the new API, search returns data source objects instead of databases.
Multiple data sources can share the same parent database, so we use
database_id as the hierarchy node key and deduplicate via
_maybe_yield_hierarchy_node.
"""
query_dict: dict[str, Any] = {
"filter": {"property": "object", "value": "data_source"},
"filter": {"property": "object", "value": "database"},
"page_size": _NOTION_PAGE_SIZE,
}
pages_seen = 0
while pages_seen < _MAX_PAGES:
db_res = self._search_notion(query_dict)
for ds in db_res.results:
# Extract the parent database_id from the data source's parent
ds_parent = ds.get("parent", {})
db_id = ds_parent.get("database_id")
if not db_id:
continue
# Populate the mapping so _get_parent_raw_id can resolve later
ds_id = ds.get("id")
if not ds_id:
continue
self._data_source_to_database_map[ds_id] = db_id
# Fetch the database to get its actual name and parent
try:
db_page = self._fetch_database_as_page(db_id)
db_name = db_page.database_name or f"Database {db_id}"
parent_raw_id = self._get_parent_raw_id(db_page.parent)
db_url = (
db_page.url or f"https://notion.so/{db_id.replace('-', '')}"
)
except requests.exceptions.RequestException as e:
logger.warning(
f"Could not fetch database '{db_id}', "
f"defaulting to workspace root. Error: {e}"
)
for db in db_res.results:
db_id = db["id"]
# Extract title from the title array
title_arr = db.get("title", [])
db_name = None
if title_arr:
db_name = " ".join(
t.get("plain_text", "") for t in title_arr
).strip()
if not db_name:
db_name = f"Database {db_id}"
parent_raw_id = self.workspace_id
db_url = f"https://notion.so/{db_id.replace('-', '')}"
# _maybe_yield_hierarchy_node deduplicates by raw_node_id,
# so multiple data sources under one database produce one node.
# Get parent using existing helper
parent_raw_id = self._get_parent_raw_id(db.get("parent"))
# Notion URLs omit dashes from UUIDs
db_url = db.get("url") or f"https://notion.so/{db_id.replace('-', '')}"
node = self._maybe_yield_hierarchy_node(
raw_node_id=db_id,
raw_parent_id=parent_raw_id or self.workspace_id,

View File

@@ -1,6 +1,5 @@
import base64
import copy
import fnmatch
import html
import io
import os
@@ -85,44 +84,6 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool:
"""Check if a site URL matches any of the exclusion glob patterns."""
for pattern in excluded_site_patterns:
if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch(
site_url.rstrip("/"), pattern.rstrip("/")
):
return True
return False
def _is_path_excluded(item_path: str, excluded_path_patterns: list[str]) -> bool:
"""Check if a drive item path matches any of the exclusion glob patterns.
item_path is the relative path within a drive, e.g. "Engineering/API/report.docx".
Matches are attempted against the full path and the filename alone so that
patterns like "*.tmp" match files at any depth.
"""
filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path
for pattern in excluded_path_patterns:
if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern):
return True
return False
def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
"""Build the relative path of a drive item from its parentReference.path and name.
Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx"
=> "Eng/API/report.docx"
"""
if parent_reference_path and "root:/" in parent_reference_path:
folder = unquote(parent_reference_path.split("root:/", 1)[1])
if folder:
return f"{folder}/{item_name}"
return item_name
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
@@ -517,7 +478,6 @@ def _convert_driveitem_to_document_with_permissions(
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
access_token: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> Document | ConnectorFailure | None:
if not driveitem.name or not driveitem.id:
@@ -650,7 +610,6 @@ def _convert_driveitem_to_document_with_permissions(
drive_item=sdk_item,
drive_name=drive_name,
add_prefix=True,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
else:
external_access = ExternalAccess.empty()
@@ -685,7 +644,6 @@ def _convert_sitepage_to_document(
graph_client: GraphClient,
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> Document:
"""Convert a SharePoint site page to a Document object."""
# Extract text content from the site page
@@ -815,7 +773,6 @@ def _convert_sitepage_to_document(
graph_client=graph_client,
site_page=site_page,
add_prefix=True,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
else:
external_access = ExternalAccess.empty()
@@ -846,7 +803,6 @@ def _convert_driveitem_to_slim_document(
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -857,7 +813,6 @@ def _convert_driveitem_to_slim_document(
graph_client=graph_client,
drive_item=sdk_item,
drive_name=drive_name,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
return SlimDocument(
@@ -872,7 +827,6 @@ def _convert_sitepage_to_slim_document(
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -882,7 +836,6 @@ def _convert_sitepage_to_slim_document(
ctx=ctx,
graph_client=graph_client,
site_page=site_page,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
id = site_page.get("id")
if id is None:
@@ -902,20 +855,14 @@ class SharepointConnector(
self,
batch_size: int = INDEX_BATCH_SIZE,
sites: list[str] = [],
excluded_sites: list[str] = [],
excluded_paths: list[str] = [],
include_site_pages: bool = True,
include_site_documents: bool = True,
treat_sharing_link_as_public: bool = False,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
) -> None:
self.batch_size = batch_size
self.sites = list(sites)
self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
self.treat_sharing_link_as_public = treat_sharing_link_as_public
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
@@ -1286,29 +1233,6 @@ class SharepointConnector(
break
sites = sites._get_next().execute_query()
def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
"""Check if a drive item should be excluded based on excluded_paths patterns."""
if not self.excluded_paths:
return False
relative_path = _build_item_relative_path(
driveitem.parent_reference_path, driveitem.name
)
return _is_path_excluded(relative_path, self.excluded_paths)
def _filter_excluded_sites(
self, site_descriptors: list[SiteDescriptor]
) -> list[SiteDescriptor]:
"""Remove sites matching any excluded_sites glob pattern."""
if not self.excluded_sites:
return site_descriptors
result = []
for sd in site_descriptors:
if _is_site_excluded(sd.url, self.excluded_sites):
logger.info(f"Excluding site by denylist: {sd.url}")
continue
result.append(sd)
return result
def fetch_sites(self) -> list[SiteDescriptor]:
sites = self.graph_client.sites.get_all_sites().execute_query()
@@ -1325,7 +1249,7 @@ class SharepointConnector(
for site in self._handle_paginated_sites(sites)
if "-my.sharepoint" not in site.web_url
]
return self._filter_excluded_sites(site_descriptors)
return site_descriptors
def _fetch_site_pages(
self,
@@ -1766,9 +1690,7 @@ class SharepointConnector(
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
site_descriptors = self.site_descriptors or self.fetch_sites()
# Create a temporary checkpoint for hierarchy node tracking
temp_checkpoint = SharepointConnectorCheckpoint(has_more=True)
@@ -1788,10 +1710,6 @@ class SharepointConnector(
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
continue
if drive_web_url:
doc_batch.extend(
self._yield_drive_hierarchy_node(
@@ -1829,7 +1747,6 @@ class SharepointConnector(
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
except Exception as e:
@@ -1853,7 +1770,6 @@ class SharepointConnector(
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:
@@ -2127,9 +2043,7 @@ class SharepointConnector(
and not checkpoint.process_site_pages
):
logger.info("Initializing SharePoint sites for processing")
site_descs = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
site_descs = self.site_descriptors or self.fetch_sites()
checkpoint.cached_site_descriptors = deque(site_descs)
if not checkpoint.cached_site_descriptors:
@@ -2350,10 +2264,6 @@ class SharepointConnector(
for driveitem in driveitems:
item_count += 1
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
continue
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
logger.debug(
f"Skipping duplicate document {driveitem.id} ({driveitem.name})"
@@ -2408,7 +2318,6 @@ class SharepointConnector(
parent_hierarchy_raw_node_id=parent_hierarchy_url,
graph_api_base=self.graph_api_base,
access_token=access_token,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
if isinstance(doc_or_failure, Document):
@@ -2489,7 +2398,6 @@ class SharepointConnector(
include_permissions=include_permissions,
# Site pages have the site as their parent
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
logger.info(

View File

@@ -17,7 +17,6 @@ def get_sharepoint_external_access(
drive_name: str | None = None,
site_page: dict[str, Any] | None = None,
add_prefix: bool = False,
treat_sharing_link_as_public: bool = False,
) -> ExternalAccess:
if drive_item and drive_item.id is None:
raise ValueError("DriveItem ID is required")
@@ -35,13 +34,7 @@ def get_sharepoint_external_access(
)
external_access = get_external_access_func(
ctx,
graph_client,
drive_name,
drive_item,
site_page,
add_prefix,
treat_sharing_link_as_public,
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
)
return external_access

View File

@@ -401,16 +401,3 @@ class SavedSearchDocWithContent(SavedSearchDoc):
section in addition to the match_highlights."""
content: str
class PersonaSearchInfo(BaseModel):
"""Snapshot of persona data needed by the search pipeline.
Extracted from the ORM Persona before the DB session is released so that
SearchTool and search_pipeline never lazy-load relationships post-commit.
"""
document_set_names: list[str]
search_start_date: datetime | None
attached_document_ids: list[str]
hierarchy_node_ids: list[int]

View File

@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.context.search.retrieval.search_runner import search_chunks
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
@@ -247,8 +247,8 @@ def search_pipeline(
document_index: DocumentIndex,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Pre-extracted persona search configuration (None when no persona)
persona_search_info: PersonaSearchInfo | None,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
@@ -263,18 +263,24 @@ def search_pipeline(
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
persona_document_sets: list[str] | None = (
persona_search_info.document_set_names if persona_search_info else None
[persona_document_set.name for persona_document_set in persona.document_sets]
if persona
else None
)
persona_time_cutoff: datetime | None = (
persona_search_info.search_start_date if persona_search_info else None
persona.search_start_date if persona else None
)
# Extract assistant knowledge filters from persona
attached_document_ids: list[str] | None = (
persona_search_info.attached_document_ids or None
if persona_search_info
[doc.id for doc in persona.attached_documents]
if persona and persona.attached_documents
else None
)
hierarchy_node_ids: list[int] | None = (
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
[node.id for node in persona.hierarchy_nodes]
if persona and persona.hierarchy_nodes
else None
)
filters = _build_index_filters(

View File

@@ -14,10 +14,6 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
@@ -53,7 +49,7 @@ def combine_retrieval_results(
return sorted_chunks
def _embed_and_hybrid_search(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
@@ -85,17 +81,6 @@ def _embed_and_hybrid_search(
return top_chunks
def _keyword_search(
query_request: ChunkIndexRequest,
document_index: NewDocumentIndex,
) -> list[InferenceChunk]:
return document_index.keyword_retrieval(
query=query_request.query,
filters=query_request.filters,
num_to_retrieve=query_request.limit or NUM_RETURNED_HITS,
)
def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
@@ -143,38 +128,21 @@ def search_chunks(
)
if normal_search_enabled:
if (
query_request.hybrid_alpha is not None
and query_request.hybrid_alpha == 0.0
and isinstance(document_index, OpenSearchOldDocumentIndex)
):
# If hybrid alpha is explicitly set to keyword only, do pure keyword
# search without generating an embedding. This is currently only
# supported with OpenSearchDocumentIndex.
opensearch_new_document_index: NewDocumentIndex = document_index._real_index
run_queries.append(
(
lambda: _keyword_search(
query_request, opensearch_new_document_index
),
(),
)
)
else:
run_queries.append(
(
_embed_and_hybrid_search,
(query_request, document_index, db_session, embedding_model),
)
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.debug(
f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}."
f"Hybrid search returned no results for query: {query_request.query}with filters: {query_request.filters}"
)
return []
return top_chunks

View File

@@ -16,7 +16,6 @@ from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
@@ -29,7 +28,6 @@ from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -55,22 +53,9 @@ def get_chat_session_by_id(
db_session: Session,
include_deleted: bool = False,
is_shared: bool = False,
eager_load_persona: bool = False,
) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
if eager_load_persona:
stmt = stmt.options(
joinedload(ChatSession.persona).options(
selectinload(Persona.tools),
selectinload(Persona.user_files),
selectinload(Persona.document_sets),
selectinload(Persona.attached_documents),
selectinload(Persona.hierarchy_nodes),
),
joinedload(ChatSession.project),
)
if is_shared:
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
else:

View File

@@ -750,31 +750,3 @@ def resync_cc_pair(
)
db_session.commit()
# ── Metrics query helpers ──────────────────────────────────────────────
def get_connector_health_for_metrics(
db_session: Session,
) -> list: # Returns list of Row tuples
"""Return connector health data for Prometheus metrics.
Each row is (cc_pair_id, status, in_repeated_error_state,
last_successful_index_time, name, source).
"""
return (
db_session.query(
ConnectorCredentialPair.id,
ConnectorCredentialPair.status,
ConnectorCredentialPair.in_repeated_error_state,
ConnectorCredentialPair.last_successful_index_time,
ConnectorCredentialPair.name,
Connector.source,
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.all()
)

View File

@@ -1,31 +1,4 @@
from __future__ import annotations
from enum import Enum as PyEnum
from typing import ClassVar
class AccountType(str, PyEnum):
"""
What kind of account this is — determines whether the user
enters the group-based permission system.
STANDARD + SERVICE_ACCOUNT → participate in group system
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "user"
SCIM = "scim"
SYSTEM = "system"
class IndexingStatus(str, PyEnum):
@@ -341,54 +314,3 @@ class HookPoint(str, PyEnum):
class HookFailStrategy(str, PyEnum):
HARD = "hard" # exception propagates, pipeline aborts
SOFT = "soft" # log error, return original input, pipeline continues
class Permission(str, PyEnum):
"""
Permission tokens for group-based authorization.
19 tokens total. full_admin_panel_access is an override —
if present, any permission check passes.
"""
# Basic (auto-granted to every new group)
BASIC_ACCESS = "basic"
# Read tokens — implied only, never granted directly
READ_CONNECTORS = "read:connectors"
READ_DOCUMENT_SETS = "read:document_sets"
READ_AGENTS = "read:agents"
READ_USERS = "read:users"
# Add / Manage pairs
ADD_AGENTS = "add:agents"
MANAGE_AGENTS = "manage:agents"
MANAGE_DOCUMENT_SETS = "manage:document_sets"
ADD_CONNECTORS = "add:connectors"
MANAGE_CONNECTORS = "manage:connectors"
MANAGE_LLMS = "manage:llms"
# Toggle tokens
READ_AGENT_ANALYTICS = "read:agent_analytics"
MANAGE_ACTIONS = "manage:actions"
READ_QUERY_HISTORY = "read:query_history"
MANAGE_USER_GROUPS = "manage:user_groups"
CREATE_USER_API_KEYS = "create:user_api_keys"
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
# Override — any permission check passes
FULL_ADMIN_PANEL_ACCESS = "admin"
# Permissions that are implied by other grants and must never be stored
# directly in the permission_grant table.
IMPLIED: ClassVar[frozenset[Permission]]
Permission.IMPLIED = frozenset(
{
Permission.READ_CONNECTORS,
Permission.READ_DOCUMENT_SETS,
Permission.READ_AGENTS,
Permission.READ_USERS,
}
)

View File

@@ -75,7 +75,6 @@ def create_hook__no_commit(
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
is_reachable: bool | None = None,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
@@ -101,7 +100,6 @@ def create_hook__no_commit(
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
is_reachable=is_reachable,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,

View File

@@ -2,8 +2,6 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import NamedTuple
from typing import TYPE_CHECKING
from typing import TypeVarTuple
from sqlalchemy import and_
@@ -30,9 +28,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
if TYPE_CHECKING:
from onyx.configs.constants import DocumentSource
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
@@ -977,106 +972,3 @@ def get_index_attempt_errors_for_cc_pair(
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.scalars(stmt).all())
# ── Metrics query helpers ──────────────────────────────────────────────
class ActiveIndexAttemptMetric(NamedTuple):
"""Row returned by get_active_index_attempts_for_metrics."""
status: IndexingStatus
source: "DocumentSource"
cc_pair_id: int
cc_pair_name: str | None
attempt_count: int
def get_active_index_attempts_for_metrics(
db_session: Session,
) -> list[ActiveIndexAttemptMetric]:
"""Return non-terminal index attempts grouped by status, source, and connector.
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
"""
from onyx.db.models import Connector
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
rows = (
db_session.query(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
func.count(),
)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.filter(IndexAttempt.status.notin_(terminal_statuses))
.group_by(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
)
.all()
)
return [ActiveIndexAttemptMetric(*row) for row in rows]
def get_failed_attempt_counts_by_cc_pair(
db_session: Session,
since: datetime | None = None,
) -> dict[int, int]:
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
When ``since`` is provided, only attempts created after that timestamp
are counted. Defaults to the last 90 days to avoid unbounded historical
aggregation.
"""
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=90)
rows = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
func.count(),
)
.filter(IndexAttempt.status == IndexingStatus.FAILED)
.filter(IndexAttempt.time_created >= since)
.group_by(IndexAttempt.connector_credential_pair_id)
.all()
)
return {cc_id: count for cc_id, count in rows}
def get_docs_indexed_by_cc_pair(
db_session: Session,
since: datetime | None = None,
) -> dict[int, int]:
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
Only counts attempts with status SUCCESS to avoid inflating counts with
partial results from failed attempts. When ``since`` is provided, only
attempts created after that timestamp are included.
"""
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=90)
query = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
)
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
.filter(IndexAttempt.time_created >= since)
.group_by(IndexAttempt.connector_credential_pair_id)
)
rows = query.all()
return {cc_id: int(total or 0) for cc_id, total in rows}

View File

@@ -48,7 +48,6 @@ from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy import PrimaryKeyConstraint
from onyx.db.enums import AccountType
from onyx.auth.schemas import UserRole
from onyx.configs.constants import (
ANONYMOUS_USER_UUID,
@@ -79,8 +78,6 @@ from onyx.db.enums import (
MCPAuthenticationPerformer,
MCPTransport,
MCPServerStatus,
Permission,
GrantSource,
LLMModelFlowType,
ThemePreference,
DefaultAppMode,
@@ -305,9 +302,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
Preferences probably should be in a separate table at some point, but for now
@@ -2651,15 +2645,6 @@ class ChatMessage(Base):
nullable=True,
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
# The display name of the model that generated this assistant message
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# What does this message contain
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -2727,12 +2712,6 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
@@ -3135,6 +3114,8 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
@@ -3990,8 +3971,6 @@ class SamlAccount(Base):
class User__UserGroup(Base):
__tablename__ = "user__user_group"
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
user_group_id: Mapped[int] = mapped_column(
@@ -4002,48 +3981,6 @@ class User__UserGroup(Base):
)
class PermissionGrant(Base):
__tablename__ = "permission_grant"
__table_args__ = (
UniqueConstraint(
"group_id", "permission", name="uq_permission_grant_group_permission"
),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(Permission, native_enum=False), nullable=False
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False
)
granted_by: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
)
granted_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
is_deleted: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default=text("false")
)
group: Mapped["UserGroup"] = relationship(
"UserGroup", back_populates="permission_grants"
)
@validates("permission")
def _validate_permission(self, _key: str, value: Permission) -> Permission:
if value in Permission.IMPLIED:
raise ValueError(
f"{value!r} is an implied permission and cannot be granted directly"
)
return value
class UserGroup__ConnectorCredentialPair(Base):
__tablename__ = "user_group__connector_credential_pair"
@@ -4138,8 +4075,6 @@ class UserGroup(Base):
is_up_for_deletion: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Last time a user updated this user group
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
@@ -4183,9 +4118,6 @@ class UserGroup(Base):
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
)
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
)
"""Tables related to Token Rate Limiting

View File

@@ -50,18 +50,8 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def get_default_behavior_persona(
db_session: Session,
eager_load_for_tools: bool = False,
) -> Persona | None:
def get_default_behavior_persona(db_session: Session) -> Persona | None:
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
if eager_load_for_tools:
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.attached_documents),
selectinload(Persona.hierarchy_nodes),
)
return db_session.scalars(stmt).first()

View File

@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
)
@@ -108,10 +119,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
"""Soft-delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
provider.deleted = True
db_session.flush()

View File

@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.
@@ -381,47 +382,6 @@ class HybridCapable(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
"""Runs keyword-only search and returns a list of inference chunks.
Args:
query: User query.
filters: Filters for things like permissions, source type, time,
etc.
num_to_retrieve: Number of highest matching chunks to return.
Returns:
Score-ranked (highest first) list of highest matching chunks.
"""
raise NotImplementedError
@abc.abstractmethod
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
"""Runs semantic-only search and returns a list of inference chunks.
Args:
query_embedding: Vector representation of the query. Must be of the
correct dimensionality for the primary index.
filters: Filters for things like permissions, source type, time,
etc.
num_to_retrieve: Number of highest matching chunks to return.
Returns:
Score-ranked (highest first) list of highest matching chunks.
"""
raise NotImplementedError
class RandomCapable(abc.ABC):
"""

View File

@@ -18,13 +18,10 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
from onyx.server.metrics.opensearch_search import observe_opensearch_search
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@@ -259,6 +256,7 @@ class OpenSearchClient(AbstractContextManager):
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
@@ -306,7 +304,6 @@ class OpenSearchIndexClient(OpenSearchClient):
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
emit_metrics: bool = True,
):
super().__init__(
host=host,
@@ -318,7 +315,6 @@ class OpenSearchIndexClient(OpenSearchClient):
timeout=timeout,
)
self._index_name = index_name
self._emit_metrics = emit_metrics
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -838,10 +834,7 @@ class OpenSearchIndexClient(OpenSearchClient):
@log_function_time(print_only=True, debug_only=True)
def search(
self,
body: dict[str, Any],
search_pipeline_id: str | None,
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
self, body: dict[str, Any], search_pipeline_id: str | None
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
"""Searches the index.
@@ -859,8 +852,6 @@ class OpenSearchIndexClient(OpenSearchClient):
documentation for more information on search request bodies.
search_pipeline_id: The ID of the search pipeline to use. If None,
the default search pipeline will be used.
search_type: Label for Prometheus metrics. Does not affect search
behavior.
Raises:
Exception: There was an error searching the index.
@@ -873,27 +864,21 @@ class OpenSearchIndexClient(OpenSearchClient):
)
result: dict[str, Any]
params = {"phase_took": "true"}
ctx = self._get_emit_metrics_context_manager(search_type)
t0 = time.perf_counter()
with ctx:
if search_pipeline_id:
result = self._client.search(
index=self._index_name,
search_pipeline=search_pipeline_id,
body=body,
params=params,
)
else:
result = self._client.search(
index=self._index_name, body=body, params=params
)
client_duration_s = time.perf_counter() - t0
if search_pipeline_id:
result = self._client.search(
index=self._index_name,
search_pipeline=search_pipeline_id,
body=body,
params=params,
)
else:
result = self._client.search(
index=self._index_name, body=body, params=params
)
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
if self._emit_metrics:
observe_opensearch_search(search_type, client_duration_s, time_took)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
@@ -929,11 +914,7 @@ class OpenSearchIndexClient(OpenSearchClient):
return search_hits
@log_function_time(print_only=True, debug_only=True)
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
) -> list[str]:
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
"""Searches the index and returns only document chunk IDs.
In order to take advantage of the performance benefits of only returning
@@ -950,8 +931,6 @@ class OpenSearchIndexClient(OpenSearchClient):
documentation for more information on search request bodies.
TODO(andrei): Make this a more deep interface; callers shouldn't
need to know to set _source: False for example.
search_type: Label for Prometheus metrics. Does not affect search
behavior.
Raises:
Exception: There was an error searching the index.
@@ -969,19 +948,13 @@ class OpenSearchIndexClient(OpenSearchClient):
)
params = {"phase_took": "true"}
ctx = self._get_emit_metrics_context_manager(search_type)
t0 = time.perf_counter()
with ctx:
result: dict[str, Any] = self._client.search(
index=self._index_name, body=body, params=params
)
client_duration_s = time.perf_counter() - t0
result: dict[str, Any] = self._client.search(
index=self._index_name, body=body, params=params
)
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
if self._emit_metrics:
observe_opensearch_search(search_type, client_duration_s, time_took)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
@@ -1098,20 +1071,6 @@ class OpenSearchIndexClient(OpenSearchClient):
if raise_on_timeout:
raise RuntimeError(error_str)
def _get_emit_metrics_context_manager(
self, search_type: OpenSearchSearchType
) -> AbstractContextManager[None]:
"""
Returns a context manager that tracks in-flight OpenSearch searches via
a Gauge if emit_metrics is True, otherwise returns a null context
manager.
"""
return (
track_opensearch_search_in_progress(search_type)
if self._emit_metrics
else nullcontext()
)
def wait_for_opensearch_with_timeout(
wait_interval_s: int = 5,

View File

@@ -53,18 +53,6 @@ DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
class OpenSearchSearchType(str, Enum):
"""Search type label used for Prometheus metrics."""
HYBRID = "hybrid"
KEYWORD = "keyword"
SEMANTIC = "semantic"
RANDOM = "random"
ID_RETRIEVAL = "id_retrieval"
DOCUMENT_IDS = "document_ids"
UNKNOWN = "unknown"
class HybridSearchSubqueryConfiguration(Enum):
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
# Current default.

View File

@@ -1,11 +1,12 @@
import json
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -43,7 +44,6 @@ from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
@@ -351,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -647,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -673,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
]
onyx_document: Document = chunks[0].source_document
onyx_document: Document = doc_chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -704,22 +709,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
document_indexing_results.append(document_insertion_record)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
return document_indexing_results
@@ -901,7 +927,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
@@ -925,8 +950,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -952,7 +975,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=normalization_pipeline_name,
search_type=OpenSearchSearchType.HYBRID,
)
# Good place for a breakpoint to inspect the search hits if you have
@@ -975,8 +997,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -996,7 +1016,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.KEYWORD,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
@@ -1017,8 +1036,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -1038,7 +1055,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.SEMANTIC,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
@@ -1070,7 +1086,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.RANDOM,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -3,8 +3,6 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import TypeAlias
from typing import TypeVar
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
@@ -50,21 +48,13 @@ from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# See https://docs.opensearch.org/latest/query-dsl/term/terms/.
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536
_T = TypeVar("_T")
TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]]
TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]]
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
# of the weights should sum to 1.
# TODO(andrei): Turn all magic dictionaries to pydantic models.
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
# of the weights should sum to 1.
def _get_hybrid_search_normalization_weights() -> list[float]:
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
@@ -326,9 +316,6 @@ class DocumentQuery:
it MUST be supplied in addition to a search pipeline. The results from
hybrid search are not meaningful without that step.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_text: The text to query for.
query_vector: The vector embedding of the text to query for.
@@ -432,9 +419,6 @@ class DocumentQuery:
This query can be directly supplied to the OpenSearch client.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
@@ -514,9 +498,6 @@ class DocumentQuery:
This query can be directly supplied to the OpenSearch client.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
@@ -782,9 +763,8 @@ class DocumentQuery:
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as
# they are included in the content. This just
# acts as a minor boost.
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
@@ -799,9 +779,6 @@ class DocumentQuery:
}
},
{
# Analyzes the query and returns results which match any
# of the query's terms. More matches result in higher
# scores.
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
@@ -811,21 +788,18 @@ class DocumentQuery:
}
},
{
# Matches an exact phrase in a specified order.
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# The number of words permitted between words of
# a query phrase and still result in a match.
"slop": 1,
"boost": 1.5,
}
}
},
],
# Ensures at least one match subquery from the query is present
# in the document. This defaults to 1, unless a filter or must
# clause is supplied, in which case it defaults to 0.
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
@@ -859,14 +833,7 @@ class DocumentQuery:
The "filter" key applies a logical AND operator to its elements, so
every subfilter must evaluate to true in order for the document to be
retrieved. This function returns a list of such subfilters.
See https://docs.opensearch.org/latest/query-dsl/compound/bool/.
TODO(ENG-3874): The terms queries returned by this function can be made
more performant for large cardinality sets by sorting the values by
their UTF-8 byte order.
TODO(ENG-3875): This function can take even better advantage of filter
caching by grouping "static" filters together into one sub-clause.
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
Args:
tenant_state: Tenant state containing the tenant ID.
@@ -911,14 +878,6 @@ class DocumentQuery:
the assistant. Matches chunks where ancestor_hierarchy_node_ids
contains any of these values.
Raises:
ValueError: document_id and attached_document_ids were supplied
together. This is not allowed because they operate on the same
schema field, and it does not semantically make sense to use
them together.
ValueError: Too many of one of the collection arguments was
supplied.
Returns:
A list of filters to be passed into the "filter" key of a search
query.
@@ -926,156 +885,61 @@ class DocumentQuery:
def _get_acl_visibility_filter(
access_control_list: list[str],
) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]:
"""Returns a filter for the access control list.
Since this returns an isolated bool should clause, it can be cached
in OpenSearch independently of other clauses in _get_search_filters.
Args:
access_control_list: The access control list to restrict
documents to.
Raises:
ValueError: The number of access control list entries is greater
than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
Returns:
A filter for the access control list.
"""
) -> dict[str, Any]:
# Logical OR operator on its elements.
acl_visibility_filter: dict[str, dict[str, Any]] = {
"bool": {
"should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}],
"minimum_should_match": 1,
}
}
if access_control_list:
if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause
# because Lucene will optimize the filtering for large sets of
# terms. Small sets of terms are not expected to perform any
# differently than individual term clauses.
acl_subclause: TermsQuery[str] = {
"terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)}
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
acl_visibility_filter["bool"]["should"].append(
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
)
for acl in access_control_list:
acl_subclause: dict[str, Any] = {
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
}
acl_visibility_filter["bool"]["should"].append(acl_subclause)
return acl_visibility_filter
def _get_source_type_filter(
source_types: list[DocumentSource],
) -> TermsQuery[str]:
"""Returns a filter for the source types.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
source_types: The source types to restrict documents to.
Raises:
ValueError: The number of source types is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the source types.
"""
if not source_types:
raise ValueError(
"source_types cannot be empty if trying to create a source type filter."
) -> dict[str, Any]:
# Logical OR operator on its elements.
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
for source_type in source_types:
source_type_filter["bool"]["should"].append(
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
)
if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
return source_type_filter
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
# Logical OR operator on its elements.
tag_filter: dict[str, Any] = {"bool": {"should": []}}
for tag in tags:
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
tag_filter["bool"]["should"].append(
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {
"terms": {
SOURCE_TYPE_FIELD_NAME: [
source_type.value for source_type in source_types
]
}
}
return tag_filter
def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]:
"""Returns a filter for the tags.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
tags: The tags to restrict documents to.
Raises:
ValueError: The number of tags is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the tags.
"""
if not tags:
raise ValueError(
"tags cannot be empty if trying to create a tag filter."
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
# Logical OR operator on its elements.
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
for document_set in document_sets:
document_set_filter["bool"]["should"].append(
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
)
if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str_list = [
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags
]
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}}
return document_set_filter
def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]:
"""Returns a filter for the document sets.
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
user_project_filter["bool"]["should"].append(
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
)
return user_project_filter
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
document_sets: The document sets to restrict documents to.
Raises:
ValueError: The number of document sets is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the document sets.
"""
if not document_sets:
raise ValueError(
"document_sets cannot be empty if trying to create a document set filter."
)
if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}}
def _get_user_project_filter(project_id: int) -> TermQuery[int]:
return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
def _get_persona_filter(persona_id: int) -> TermQuery[int]:
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
@@ -1083,9 +947,7 @@ class DocumentQuery:
# document data.
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
# Logical OR operator on its elements.
time_cutoff_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
time_cutoff_filter["bool"]["should"].append(
{
"range": {
@@ -1120,77 +982,25 @@ class DocumentQuery:
def _get_attached_document_id_filter(
doc_ids: list[str],
) -> TermsQuery[str]:
"""
Returns a filter for documents explicitly attached to an assistant.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
doc_ids: The document IDs to restrict documents to.
Raises:
ValueError: The number of document IDs is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the document IDs.
"""
if not doc_ids:
raise ValueError(
"doc_ids cannot be empty if trying to create a document ID filter."
) -> dict[str, Any]:
"""Filter for documents explicitly attached to an assistant."""
# Logical OR operator on its elements.
doc_id_filter: dict[str, Any] = {"bool": {"should": []}}
for doc_id in doc_ids:
doc_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": doc_id}}}
)
if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}}
return doc_id_filter
def _get_hierarchy_node_filter(
node_ids: list[int],
) -> TermsQuery[int]:
) -> dict[str, Any]:
"""Filter for chunks whose ancestors include any of the given hierarchy nodes.
Uses a terms query to check if ancestor_hierarchy_node_ids contains
any of the specified node IDs.
"""
Returns a filter for chunks whose ancestors include any of the given
hierarchy nodes.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
node_ids: The hierarchy node IDs to restrict documents to.
Raises:
ValueError: The number of hierarchy node IDs is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the hierarchy node IDs.
"""
if not node_ids:
raise ValueError(
"node_ids cannot be empty if trying to create a hierarchy node ID filter."
)
if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}}
if document_id is not None and attached_document_ids is not None:
raise ValueError(
"document_id and attached_document_ids cannot be used together."
)
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
filter_clauses: list[dict[str, Any]] = []
@@ -1235,9 +1045,6 @@ class DocumentQuery:
)
if has_knowledge_scope:
# Since this returns an isolated bool should clause, it can be
# cached in OpenSearch independently of other clauses in
# _get_search_filters.
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}

View File

@@ -6,6 +6,7 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,6 +1,8 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -8,6 +10,7 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
existing_docs: set[str] = set()
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
@@ -610,22 +626,6 @@ class VespaDocumentIndex(DocumentIndex):
return cleanup_content_for_chunks(query_vespa(params))
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
def semantic_retrieval(
self,
query_embedding: Embedding,
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
def random_retrieval(
self,
filters: IndexFilters,

View File

@@ -44,7 +44,6 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -5,7 +5,6 @@ Usage (Celery tasks and FastAPI handlers):
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
@@ -15,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
@@ -54,11 +53,9 @@ The executor uses three sessions:
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -84,9 +81,6 @@ class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
@@ -274,21 +268,22 @@ def _persist_result(
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -305,36 +300,13 @@ def _execute_hook_inner(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -351,41 +323,8 @@ def _execute_hook_inner(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise
return outcome.response_payload

View File

@@ -91,8 +91,6 @@ class HookResponse(BaseModel):
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
api_key_masked: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool

View File

@@ -51,12 +51,13 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every subclass declares all required class attributes.
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -26,8 +26,6 @@ class DocumentIngestionSpec(HookPointSpec):
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)
@@ -25,7 +25,7 @@ class QueryProcessingResponse(BaseModel):
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, whitespace-only, or absent = reject the query."
"Null, empty string, or absent = reject the query."
),
)
rejection_message: str | None = Field(
@@ -65,8 +65,6 @@ class QueryProcessingSpec(HookPointSpec):
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib
from collections.abc import Generator
@@ -19,7 +21,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -85,14 +88,21 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
tenant_id: str,
chunks: list[DocAwareChunk],
) -> DocumentChunkEnricher:
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
no_access = DocumentAccess.build(
user_emails=[],
@@ -102,67 +112,30 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
def _get_ancestor_ids_for_documents(
@@ -203,7 +176,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -227,7 +200,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -249,3 +222,52 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -24,7 +27,8 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -101,13 +105,20 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
no_access = DocumentAccess.build(
user_emails=[],
@@ -117,7 +128,6 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -138,17 +148,6 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -163,15 +162,9 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
@@ -181,28 +174,16 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -246,8 +227,9 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -263,8 +245,10 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
str(user_file.id)
]
@@ -276,8 +260,54 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -1,5 +1,7 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from typing import cast
from typing import Protocol
from pydantic import BaseModel
@@ -47,6 +49,7 @@ from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -91,6 +94,15 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -672,12 +684,7 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
return IndexingPipelineResult.empty(len(filtered_documents))
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -748,19 +755,29 @@ def index_doc_batch(
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=cast(list[DocAwareChunk], chunks_with_embeddings),
)
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
metadata_aware_chunks = [
enricher.enrich_chunk(chunk, score)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
short_descriptor_list = [
chunk.to_short_descriptor() for chunk in metadata_aware_chunks
]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
def chunk_iterable_creator() -> Iterable[DocMetadataAwareIndexChunk]:
return metadata_aware_chunks
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@@ -770,10 +787,10 @@ def index_doc_batch(
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
make_chunks=chunk_iterable_creator,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
@@ -802,7 +819,7 @@ def index_doc_batch(
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
f"This occurred for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
@@ -815,7 +832,7 @@ def index_doc_batch(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
enrichment=enricher,
)
assert primary_doc_idx_insertion_records is not None

View File

@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
class IndexingBatchAdapter(Protocol):
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None: ...

View File

@@ -1,6 +1,9 @@
import time
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present.
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous batches
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
chunks=make_chunks(),
index_batch_params=index_batch_params,
)
),
@@ -60,14 +63,16 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
try:
insertion_records.extend(
document_index.index(
@@ -87,9 +92,7 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
document_link=first_chunk.get_link(),
),
failure_message=str(e),
exception=e,

View File

@@ -25,7 +25,6 @@ class LlmProviderNames(str, Enum):
LM_STUDIO = "lm_studio"
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
def __str__(self) -> str:
"""Needed so things like:
@@ -45,7 +44,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
]
@@ -63,7 +61,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -115,7 +112,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
}
# Model family name mappings for display name generation

View File

@@ -290,17 +290,6 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
model_kwargs["api_base"] = self._api_base
# This is needed for Ollama to do proper function calling
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
model_kwargs["api_base"] = api_base
@@ -412,20 +401,14 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
model = (
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
)
# Tool choice
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
@@ -500,11 +483,10 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
if not (is_claude_model or is_ollama or is_mistral):
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is
# routed through Bifrost's OpenAI-compatible endpoint.
# so it must be conditionally included.
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
# See also, https://github.com/ollama/ollama/issues/11171
optional_kwargs["allowed_openai_params"] = ["tool_choice"]

View File

@@ -11,7 +11,6 @@ class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
display_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -13,8 +13,6 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
@@ -50,7 +49,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
}

View File

@@ -690,9 +690,9 @@
}
},
"node_modules/@dotenvx/dotenvx/node_modules/picomatch": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"license": "MIT",
"engines": {
"node": ">=12"
@@ -9537,9 +9537,9 @@
"license": "ISC"
},
"node_modules/picomatch": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
"license": "MIT",
"engines": {
"node": ">=8.6"
@@ -11118,9 +11118,9 @@
}
},
"node_modules/tinyglobby/node_modules/picomatch": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"engines": {

View File

@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
@@ -63,9 +62,6 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
api_key_masked=(
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
),
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
@@ -142,11 +138,19 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: timeout for %s",
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
@@ -216,8 +220,8 @@ def create_hook(
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created active.
"""
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
@@ -236,10 +240,9 @@ def create_hook(
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
is_active=True,
is_reachable=True,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)

View File

@@ -57,8 +57,6 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import BifrostFinalModelResponse
from onyx.server.manage.llm.models import BifrostModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
@@ -1424,26 +1422,11 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="LiteLLM proxy",
api_key=api_key,
)
def _get_openai_compatible_models_response(
url: str,
source_name: str,
api_key: str | None = None,
) -> dict:
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
if not api_key:
headers.pop("Authorization")
try:
response = httpx.get(url, headers=headers, timeout=10.0)
@@ -1453,125 +1436,20 @@ def _get_openai_compatible_models_response(
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Authentication failed: invalid or missing API key for {source_name}.",
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
f"Failed to fetch LiteLLM models: {e}",
)
except httpx.RequestError as e:
logger.warning(
"Failed to fetch models from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
f"Failed to fetch LiteLLM models: {e}",
)
except ValueError as e:
logger.warning(
"Received invalid model response from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
)
@admin_router.post("/bifrost/available-models")
def get_bifrost_available_models(
request: BifrostModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[BifrostFinalModelResponse]:
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
response_json = _get_bifrost_models_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Bifrost endpoint",
)
results: list[BifrostFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
BifrostFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse Bifrost model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Bifrost",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="Bifrost",
)
return sorted_results
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="Bifrost",
api_key=api_key,
)

View File

@@ -449,18 +449,3 @@ class LitellmModelDetails(BaseModel):
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")
# Bifrost dynamic models fetch
class BifrostModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class BifrostFinalModelResponse(BaseModel):
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -25,7 +25,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.BEDROCK,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
}
)
@@ -51,25 +50,6 @@ BEDROCK_VISION_MODELS = frozenset(
}
)
# Known Bifrost/OpenAI-compatible vision-capable model families where the
# source API does not expose this metadata directly.
BIFROST_VISION_MODEL_FAMILIES = frozenset(
{
"anthropic/claude-3",
"anthropic/claude-4",
"amazon/nova-pro",
"amazon/nova-lite",
"amazon/nova-premier",
"openai/gpt-4o",
"openai/gpt-4.1",
"google/gemini",
"meta-llama/llama-3.2",
"mistral/pixtral",
"qwen/qwen2.5-vl",
"qwen/qwen-vl",
}
)
def is_valid_bedrock_model(
model_id: str,
@@ -96,18 +76,11 @@ def is_valid_bedrock_model(
def infer_vision_support(model_id: str) -> bool:
"""Infer vision support from model ID when base model metadata unavailable.
Used for providers like Bedrock and Bifrost where vision support may
need to be inferred from vendor/model naming conventions.
Used for cross-region inference profiles when the base model isn't
available in the user's region.
"""
model_id_lower = model_id.lower()
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
return True
normalized_model_id = model_id_lower.replace(".", "/")
return any(
vision_model in normalized_model_id
for vision_model in BIFROST_VISION_MODEL_FAMILIES
)
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
def generate_bedrock_display_name(model_id: str) -> str:
@@ -349,7 +322,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
- Ollama: "llama3:70b""Meta"
- Ollama: "qwen2.5:7b""Alibaba"
"""
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
if provider == LlmProviderNames.OPENROUTER:
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
if "/" in model_name:
vendor_key = model_name.split("/")[0].lower()

View File

@@ -1,207 +0,0 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
Usage in a worker app module:
from onyx.server.metrics.celery_task_metrics import (
on_celery_task_prerun,
on_celery_task_postrun,
on_celery_task_retry,
on_celery_task_revoked,
on_celery_task_rejected,
)
# Call from the worker's existing signal handlers
"""
import threading
import time
from celery import Task
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
TASK_STARTED = Counter(
"onyx_celery_task_started_total",
"Total Celery tasks started",
["task_name", "queue"],
)
TASK_COMPLETED = Counter(
"onyx_celery_task_completed_total",
"Total Celery tasks completed",
["task_name", "queue", "outcome"],
)
TASK_DURATION = Histogram(
"onyx_celery_task_duration_seconds",
"Celery task execution duration in seconds",
["task_name", "queue"],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
TASKS_ACTIVE = Gauge(
"onyx_celery_tasks_active",
"Currently executing Celery tasks",
["task_name", "queue"],
)
TASK_RETRIED = Counter(
"onyx_celery_task_retried_total",
"Total Celery tasks retried",
["task_name", "queue"],
)
TASK_REVOKED = Counter(
"onyx_celery_task_revoked_total",
"Total Celery tasks revoked (cancelled)",
["task_name"],
)
TASK_REJECTED = Counter(
"onyx_celery_task_rejected_total",
"Total Celery tasks rejected by worker",
["task_name"],
)
# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
# Lock protecting _task_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_task_start_times_lock = threading.Lock()
# Entries older than this are evicted on each prerun to prevent unbounded
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
_MAX_START_TIME_AGE_SECONDS = 3600 # 1 hour
def _evict_stale_start_times() -> None:
"""Remove _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
Must be called while holding _task_start_times_lock.
"""
now = time.monotonic()
stale_ids = [
tid
for tid, (start, _labels) in _task_start_times.items()
if now - start > _MAX_START_TIME_AGE_SECONDS
]
for tid in stale_ids:
entry = _task_start_times.pop(tid, None)
if entry is not None:
_labels = entry[1]
# Decrement active gauge for evicted tasks — these tasks were
# started but never completed (killed, OOM, etc.).
active_gauge = TASKS_ACTIVE.labels(**_labels)
if active_gauge._value.get() > 0:
active_gauge.dec()
def _get_task_labels(task: Task) -> dict[str, str]:
"""Extract task_name and queue labels from a Celery Task instance."""
task_name = task.name or "unknown"
queue = "unknown"
try:
delivery_info = task.request.delivery_info
if delivery_info:
queue = delivery_info.get("routing_key") or "unknown"
except AttributeError:
pass
return {"task_name": task_name, "queue": queue}
def on_celery_task_prerun(
task_id: str | None,
task: Task | None,
) -> None:
"""Record task start. Call from the worker's task_prerun signal handler."""
if task is None or task_id is None:
return
try:
labels = _get_task_labels(task)
TASK_STARTED.labels(**labels).inc()
TASKS_ACTIVE.labels(**labels).inc()
with _task_start_times_lock:
_evict_stale_start_times()
_task_start_times[task_id] = (time.monotonic(), labels)
except Exception:
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
def on_celery_task_postrun(
task_id: str | None,
task: Task | None,
state: str | None,
) -> None:
"""Record task completion. Call from the worker's task_postrun signal handler."""
if task is None or task_id is None:
return
try:
labels = _get_task_labels(task)
outcome = "success" if state == "SUCCESS" else "failure"
TASK_COMPLETED.labels(**labels, outcome=outcome).inc()
# Guard against going below 0 if postrun fires without a matching
# prerun (e.g. after a worker restart or stale entry eviction).
active_gauge = TASKS_ACTIVE.labels(**labels)
if active_gauge._value.get() > 0:
active_gauge.dec()
with _task_start_times_lock:
entry = _task_start_times.pop(task_id, None)
if entry is not None:
start_time, _stored_labels = entry
TASK_DURATION.labels(**labels).observe(time.monotonic() - start_time)
except Exception:
logger.debug("Failed to record celery task postrun metrics", exc_info=True)
def on_celery_task_retry(
_task_id: str | None,
task: Task | None,
) -> None:
"""Record task retry. Call from the worker's task_retry signal handler."""
if task is None:
return
try:
labels = _get_task_labels(task)
TASK_RETRIED.labels(**labels).inc()
except Exception:
logger.debug("Failed to record celery task retry metrics", exc_info=True)
def on_celery_task_revoked(
_task_id: str | None,
task_name: str | None = None,
) -> None:
"""Record task revocation. The revoked signal doesn't provide a Task
instance, only the task name via sender."""
if task_name is None:
return
try:
TASK_REVOKED.labels(task_name=task_name).inc()
except Exception:
logger.debug("Failed to record celery task revoked metrics", exc_info=True)
def on_celery_task_rejected(
_task_id: str | None,
task_name: str | None = None,
) -> None:
"""Record task rejection."""
if task_name is None:
return
try:
TASK_REJECTED.labels(task_name=task_name).inc()
except Exception:
logger.debug("Failed to record celery task rejected metrics", exc_info=True)

View File

@@ -1,594 +0,0 @@
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
These collectors query Redis and Postgres at scrape time (the Collector pattern),
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
monitoring celery worker which already has Redis and DB access.
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
stale, which is fine for monitoring dashboards.
"""
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import OnyxCeleryQueues
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default cache TTL in seconds. Scrapes hitting within this window return
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0
_QUEUE_LABEL_MAP: dict[str, str] = {
OnyxCeleryQueues.PRIMARY: "primary",
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
OnyxCeleryQueues.MONITORING: "monitoring",
OnyxCeleryQueues.SANDBOX: "sandbox",
OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
}
# Queues where prefetched (unacked) task counts are meaningful
_UNACKED_QUEUES: list[str] = [
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
OnyxCeleryQueues.DOCPROCESSING,
]
class _CachedCollector(Collector):
"""Base collector with TTL-based caching.
Subclasses implement ``_collect_fresh()`` to query the actual data source.
The base ``collect()`` returns cached results if the TTL hasn't expired,
avoiding repeated queries when Prometheus scrapes frequently.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
self._cache_ttl = cache_ttl
self._cached_result: list[GaugeMetricFamily] | None = None
self._last_collect_time: float = 0.0
self._lock = threading.Lock()
def collect(self) -> list[GaugeMetricFamily]:
with self._lock:
now = time.monotonic()
if (
now - self._last_collect_time < self._cache_ttl
and self._cached_result is not None
):
return self._cached_result
try:
result = self._collect_fresh()
self._cached_result = result
self._last_collect_time = now
return result
except Exception:
logger.exception(f"Error in {type(self).__name__}.collect()")
# Return stale cache on error rather than nothing — avoids
# metrics disappearing during transient failures.
return self._cached_result if self._cached_result is not None else []
def _collect_fresh(self) -> list[GaugeMetricFamily]:
raise NotImplementedError
def describe(self) -> list[GaugeMetricFamily]:
return []
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
return []
redis_client = self._get_redis()
depth = GaugeMetricFamily(
"onyx_queue_depth",
"Number of tasks waiting in Celery queue",
labels=["queue"],
)
unacked = GaugeMetricFamily(
"onyx_queue_unacked",
"Number of prefetched (unacked) tasks for queue",
labels=["queue"],
)
queue_age = GaugeMetricFamily(
"onyx_queue_oldest_task_age_seconds",
"Age of the oldest task in the queue (seconds since enqueue)",
labels=["queue"],
)
now = time.time()
for queue_name, label in _QUEUE_LABEL_MAP.items():
length = celery_get_queue_length(queue_name, redis_client)
depth.add_metric([label], length)
# Peek at the oldest message to get its age
if length > 0:
age = self._get_oldest_message_age(redis_client, queue_name, now)
if age is not None:
queue_age.add_metric([label], age)
for queue_name in _UNACKED_QUEUES:
label = _QUEUE_LABEL_MAP[queue_name]
task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
unacked.add_metric([label], len(task_ids))
return [depth, unacked, queue_age]
@staticmethod
def _get_oldest_message_age(
redis_client: Redis, queue_name: str, now: float
) -> float | None:
"""Peek at the oldest (tail) message in a Redis list queue
and extract its timestamp to compute age.
Note: If the Celery message contains neither ``properties.timestamp``
nor ``headers.timestamp``, no age metric is emitted for this queue.
This can happen with custom task producers or non-standard Celery
protocol versions. The metric will simply be absent rather than
inaccurate, which is the safest behavior for alerting.
"""
try:
raw: bytes | str | None = redis_client.lindex(queue_name, -1) # type: ignore[assignment]
if raw is None:
return None
msg = json.loads(raw)
# Check for ETA tasks first — they are intentionally delayed,
# so reporting their queue age would be misleading.
headers = msg.get("headers", {})
if headers.get("eta") is not None:
return None
# Celery v2 protocol: timestamp in properties
props = msg.get("properties", {})
ts = props.get("timestamp")
if ts is not None:
return now - float(ts)
# Fallback: some Celery configurations place the timestamp in
# headers instead of properties.
ts = headers.get("timestamp")
if ts is not None:
return now - float(ts)
except Exception:
pass
return None
class IndexAttemptCollector(_CachedCollector):
"""Queries Postgres for index attempt state on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._configured: bool = False
self._terminal_statuses: list = []
def configure(self) -> None:
"""Call once DB engine is initialized."""
from onyx.db.enums import IndexingStatus
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
self._configured = True
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if not self._configured:
return []
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
attempts_gauge = GaugeMetricFamily(
"onyx_index_attempts_active",
"Number of non-terminal index attempts",
labels=[
"status",
"source",
"tenant_id",
"connector_name",
"cc_pair_id",
],
)
tenant_ids = get_all_tenant_ids()
for tid in tenant_ids:
# Defensive guard — get_all_tenant_ids() should never yield None,
# but we guard here for API stability in case the contract changes.
if tid is None:
continue
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
try:
with get_session_with_current_tenant() as session:
rows = get_active_index_attempts_for_metrics(session)
for status, source, cc_id, cc_name, count in rows:
name_val = cc_name or f"cc_pair_{cc_id}"
attempts_gauge.add_metric(
[
status.value,
source.value,
tid,
name_val,
str(cc_id),
],
count,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return [attempts_gauge]
class ConnectorHealthCollector(_CachedCollector):
"""Queries Postgres for connector health state on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._configured: bool = False
def configure(self) -> None:
"""Call once DB engine is initialized."""
self._configured = True
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if not self._configured:
return []
from onyx.db.connector_credential_pair import (
get_connector_health_for_metrics,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
staleness_gauge = GaugeMetricFamily(
"onyx_connector_last_success_age_seconds",
"Seconds since last successful index for this connector",
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
)
error_state_gauge = GaugeMetricFamily(
"onyx_connector_in_error_state",
"Whether the connector is in a repeated error state (1=yes, 0=no)",
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
)
by_status_gauge = GaugeMetricFamily(
"onyx_connectors_by_status",
"Number of connectors grouped by status",
labels=["tenant_id", "status"],
)
error_total_gauge = GaugeMetricFamily(
"onyx_connectors_in_error_total",
"Total number of connectors in repeated error state",
labels=["tenant_id"],
)
per_connector_labels = [
"tenant_id",
"source",
"cc_pair_id",
"connector_name",
]
docs_success_gauge = GaugeMetricFamily(
"onyx_connector_docs_indexed",
"Total new documents indexed (90-day rolling sum) per connector",
labels=per_connector_labels,
)
docs_error_gauge = GaugeMetricFamily(
"onyx_connector_error_count",
"Total number of failed index attempts per connector",
labels=per_connector_labels,
)
now = datetime.now(tz=timezone.utc)
tenant_ids = get_all_tenant_ids()
for tid in tenant_ids:
# Defensive guard — get_all_tenant_ids() should never yield None,
# but we guard here for API stability in case the contract changes.
if tid is None:
continue
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
try:
with get_session_with_current_tenant() as session:
pairs = get_connector_health_for_metrics(session)
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
docs_by_cc = get_docs_indexed_by_cc_pair(session)
status_counts: dict[str, int] = {}
error_count = 0
for (
cc_id,
status,
in_error,
last_success,
cc_name,
source,
) in pairs:
cc_id_str = str(cc_id)
source_val = source.value
name_val = cc_name or f"cc_pair_{cc_id}"
label_vals = [tid, source_val, cc_id_str, name_val]
if last_success is not None:
# Both `now` and `last_success` are timezone-aware
# (the DB column uses DateTime(timezone=True)),
# so subtraction is safe.
age = (now - last_success).total_seconds()
staleness_gauge.add_metric(label_vals, age)
error_state_gauge.add_metric(
label_vals,
1.0 if in_error else 0.0,
)
if in_error:
error_count += 1
docs_success_gauge.add_metric(
label_vals,
docs_by_cc.get(cc_id, 0),
)
docs_error_gauge.add_metric(
label_vals,
error_counts_by_cc.get(cc_id, 0),
)
status_val = status.value
status_counts[status_val] = status_counts.get(status_val, 0) + 1
for status_val, count in status_counts.items():
by_status_gauge.add_metric([tid, status_val], count)
error_total_gauge.add_metric([tid], error_count)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return [
staleness_gauge,
error_state_gauge,
by_status_gauge,
error_total_gauge,
docs_success_gauge,
docs_error_gauge,
]
class RedisHealthCollector(_CachedCollector):
"""Collects Redis server health metrics (memory, clients, etc.)."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._get_redis: Callable[[], Redis] | None = None
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._get_redis is None:
return []
redis_client = self._get_redis()
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",
"Redis used memory in bytes",
)
memory_peak = GaugeMetricFamily(
"onyx_redis_memory_peak_bytes",
"Redis peak used memory in bytes",
)
memory_frag = GaugeMetricFamily(
"onyx_redis_memory_fragmentation_ratio",
"Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
)
connected_clients = GaugeMetricFamily(
"onyx_redis_connected_clients",
"Number of connected Redis clients",
)
try:
mem_info: dict = redis_client.info("memory") # type: ignore[assignment]
memory_used.add_metric([], mem_info.get("used_memory", 0))
memory_peak.add_metric([], mem_info.get("used_memory_peak", 0))
frag = mem_info.get("mem_fragmentation_ratio")
if frag is not None:
memory_frag.add_metric([], frag)
client_info: dict = redis_client.info("clients") # type: ignore[assignment]
connected_clients.add_metric([], client_info.get("connected_clients", 0))
except Exception:
logger.debug("Failed to collect Redis health metrics", exc_info=True)
return [memory_used, memory_peak, memory_frag, connected_clients]
class WorkerHeartbeatMonitor:
"""Monitors Celery worker health via the event stream.
Subscribes to ``worker-heartbeat``, ``worker-online``, and
``worker-offline`` events via a single persistent connection.
Runs in a daemon thread started once during worker setup.
"""
# Consider a worker down if no heartbeat received for this long.
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
def __init__(self, celery_app: Any) -> None:
self._app = celery_app
self._worker_last_seen: dict[str, float] = {}
self._lock = threading.Lock()
self._running = False
self._thread: threading.Thread | None = None
def start(self) -> None:
"""Start the background event listener thread.
Safe to call multiple times — only starts one thread.
"""
if self._thread is not None and self._thread.is_alive():
return
self._running = True
self._thread = threading.Thread(target=self._listen, daemon=True)
self._thread.start()
logger.info("WorkerHeartbeatMonitor started")
def stop(self) -> None:
self._running = False
def _listen(self) -> None:
"""Background loop: connect to event stream and process heartbeats."""
while self._running:
try:
with self._app.connection() as conn:
recv = self._app.events.Receiver(
conn,
handlers={
"worker-heartbeat": self._on_heartbeat,
"worker-online": self._on_heartbeat,
"worker-offline": self._on_offline,
},
)
recv.capture(
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
)
except Exception:
if self._running:
logger.debug(
"Heartbeat listener disconnected, reconnecting in 5s",
exc_info=True,
)
time.sleep(5.0)
else:
# capture() returned normally (timeout with no events); reconnect
if self._running:
logger.debug("Heartbeat capture timed out, reconnecting")
time.sleep(5.0)
def _on_heartbeat(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen[hostname] = time.monotonic()
def _on_offline(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen.pop(hostname, None)
def get_worker_status(self) -> dict[str, bool]:
"""Return {hostname: is_alive} for all known workers.
Thread-safe. Called by WorkerHealthCollector on each scrape.
Also prunes workers that have been dead longer than 2x the
heartbeat timeout to prevent unbounded growth.
"""
now = time.monotonic()
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
with self._lock:
# Prune workers that have been gone for 2x the timeout
stale = [
h
for h, ts in self._worker_last_seen.items()
if (now - ts) > prune_threshold
]
for h in stale:
del self._worker_last_seen[h]
result: dict[str, bool] = {}
for hostname, last_seen in self._worker_last_seen.items():
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
result[hostname] = alive
return result
class WorkerHealthCollector(_CachedCollector):
"""Collects Celery worker health from the heartbeat monitor.
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
to the Celery event stream via a single persistent connection.
"""
def __init__(self, cache_ttl: float = 30.0) -> None:
super().__init__(cache_ttl)
self._monitor: WorkerHeartbeatMonitor | None = None
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
"""Set the heartbeat monitor instance."""
self._monitor = monitor
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._monitor is None:
return []
active_workers = GaugeMetricFamily(
"onyx_celery_active_worker_count",
"Number of active Celery workers with recent heartbeats",
)
worker_up = GaugeMetricFamily(
"onyx_celery_worker_up",
"Whether a specific Celery worker is alive (1=up, 0=down)",
labels=["worker"],
)
try:
status = self._monitor.get_worker_status()
alive_count = sum(1 for alive in status.values() if alive)
active_workers.add_metric([], alive_count)
for hostname in sorted(status):
# Use short name (before @) for single-host deployments,
# full hostname when multiple hosts share a worker type.
label = hostname.split("@")[0]
worker_up.add_metric([label], 1 if status[hostname] else 0)
except Exception:
logger.debug("Failed to collect worker health metrics", exc_info=True)
return [active_workers, worker_up]

View File

@@ -1,124 +0,0 @@
"""Setup function for indexing pipeline Prometheus collectors.
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a fresh
broker Redis client on each scrape for queue depth metrics.
"""
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.
# Module-level singleton prevents duplicate threads on re-entry.
global _heartbeat_monitor
if _heartbeat_monitor is None:
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
_heartbeat_monitor.start()
_worker_health_collector.set_monitor(_heartbeat_monitor)
_attempt_collector.configure()
_connector_collector.configure()
for collector in (
_queue_collector,
_attempt_collector,
_connector_collector,
_redis_health_collector,
_worker_health_collector,
):
try:
REGISTRY.register(collector)
except ValueError:
logger.debug("Collector already registered: %s", type(collector).__name__)

View File

@@ -1,253 +0,0 @@
"""Per-connector Prometheus metrics for indexing tasks.
Enriches the two primary indexing tasks (docfetching_proxy_task and
docprocessing_task) with connector-level labels: source, tenant_id,
and cc_pair_id.
Note: connector_name is intentionally excluded from push-based per-task
counters because it is a user-defined free-form string that can create
unbounded cardinality. The pull-based collectors on the monitoring worker
(see indexing_pipeline.py) include connector_name since they have bounded
cardinality (one series per connector, not per task execution).
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
Connectors never change source type, and names change rarely, so the
cache is safe to hold for the worker's lifetime.
Usage in a worker app module:
from onyx.server.metrics.indexing_task_metrics import (
on_indexing_task_prerun,
on_indexing_task_postrun,
)
"""
import threading
import time
from dataclasses import dataclass
from celery import Task
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.configs.constants import OnyxCeleryTask
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@dataclass(frozen=True)
class ConnectorInfo:
"""Cached connector metadata for metric labels."""
source: str
name: str
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")
# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
# deployments where different tenants can share the same cc_pair_id value.
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
# Lock protecting _connector_cache — multiple thread-pool workers may
# resolve connectors concurrently.
_connector_cache_lock = threading.Lock()
# Only enrich these task types with per-connector labels
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
{
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
OnyxCeleryTask.DOCPROCESSING_TASK,
}
)
# connector_name is intentionally excluded — see module docstring.
INDEXING_TASK_STARTED = Counter(
"onyx_indexing_task_started_total",
"Indexing tasks started per connector",
["task_name", "source", "tenant_id", "cc_pair_id"],
)
INDEXING_TASK_COMPLETED = Counter(
"onyx_indexing_task_completed_total",
"Indexing tasks completed per connector",
[
"task_name",
"source",
"tenant_id",
"cc_pair_id",
"outcome",
],
)
INDEXING_TASK_DURATION = Histogram(
"onyx_indexing_task_duration_seconds",
"Indexing task duration by connector type",
["task_name", "source", "tenant_id"],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
# task_id → monotonic start time (for indexing tasks only)
_indexing_start_times: dict[str, float] = {}
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_indexing_start_times_lock = threading.Lock()
def _evict_stale_start_times() -> None:
"""Remove _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
Must be called while holding _indexing_start_times_lock.
"""
now = time.monotonic()
stale_ids = [
tid
for tid, start in _indexing_start_times.items()
if now - start > _MAX_START_TIME_AGE_SECONDS
]
for tid in stale_ids:
_indexing_start_times.pop(tid, None)
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
"""Resolve cc_pair_id to ConnectorInfo, using cache when possible.
On cache miss, does a single DB query with eager connector load.
On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
subsequent calls can retry the lookup once the DB is available.
Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
cache key. The Celery tenant-aware middleware sets this contextvar before
task execution, and it always matches kwargs["tenant_id"] (which is set
at task dispatch time). They are guaranteed to agree for a given task
execution context.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
cache_key = (tenant_id, cc_pair_id)
with _connector_cache_lock:
cached = _connector_cache.get(cache_key)
if cached is not None:
return cached
try:
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session,
cc_pair_id,
eager_load_connector=True,
)
if cc_pair is None:
# DB lookup succeeded but cc_pair doesn't exist — don't cache,
# it may appear later (race with connector creation).
return _UNKNOWN_CONNECTOR
info = ConnectorInfo(
source=cc_pair.connector.source.value,
name=cc_pair.name,
)
with _connector_cache_lock:
_connector_cache[cache_key] = info
return info
except Exception:
logger.debug(
f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
exc_info=True,
)
return _UNKNOWN_CONNECTOR
def on_indexing_task_prerun(
task_id: str | None,
task: Task | None,
kwargs: dict | None,
) -> None:
"""Record per-connector metrics at task start.
Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
all other tasks.
"""
if task is None or task_id is None or kwargs is None:
return
task_name = task.name or ""
if task_name not in _INDEXING_TASK_NAMES:
return
try:
cc_pair_id = kwargs.get("cc_pair_id")
tenant_id = str(kwargs.get("tenant_id", "unknown"))
if cc_pair_id is None:
return
info = _resolve_connector(cc_pair_id)
INDEXING_TASK_STARTED.labels(
task_name=task_name,
source=info.source,
tenant_id=tenant_id,
cc_pair_id=str(cc_pair_id),
).inc()
with _indexing_start_times_lock:
_evict_stale_start_times()
_indexing_start_times[task_id] = time.monotonic()
except Exception:
logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
def on_indexing_task_postrun(
task_id: str | None,
task: Task | None,
kwargs: dict | None,
state: str | None,
) -> None:
"""Record per-connector completion metrics.
Only fires for tasks in _INDEXING_TASK_NAMES.
"""
if task is None or task_id is None or kwargs is None:
return
task_name = task.name or ""
if task_name not in _INDEXING_TASK_NAMES:
return
try:
cc_pair_id = kwargs.get("cc_pair_id")
tenant_id = str(kwargs.get("tenant_id", "unknown"))
if cc_pair_id is None:
return
info = _resolve_connector(cc_pair_id)
outcome = "success" if state == "SUCCESS" else "failure"
INDEXING_TASK_COMPLETED.labels(
task_name=task_name,
source=info.source,
tenant_id=tenant_id,
cc_pair_id=str(cc_pair_id),
outcome=outcome,
).inc()
with _indexing_start_times_lock:
start = _indexing_start_times.pop(task_id, None)
if start is not None:
INDEXING_TASK_DURATION.labels(
task_name=task_name,
source=info.source,
tenant_id=tenant_id,
).observe(time.monotonic() - start)
except Exception:
logger.debug("Failed to record indexing task postrun metrics", exc_info=True)

View File

@@ -1,89 +0,0 @@
"""Standalone Prometheus metrics HTTP server for non-API processes.
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
Celery workers and other background processes use this module to expose their
own /metrics endpoint on a configurable port.
Usage:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring") # reads port from env or uses default
"""
import os
import threading
from prometheus_client import start_http_server
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default ports for worker types that serve custom Prometheus metrics.
# Only add entries here when a worker actually registers collectors.
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
# env var can override.
_DEFAULT_PORTS: dict[str, int] = {
"monitoring": 9096,
"docfetching": 9092,
"docprocessing": 9093,
}
_server_started = False
_server_lock = threading.Lock()
def start_metrics_server(worker_type: str) -> int | None:
"""Start a Prometheus metrics HTTP server in a background thread.
Returns the port if started, None if disabled or already started.
Port resolution order:
1. PROMETHEUS_METRICS_PORT env var (explicit override)
2. Default port for the worker type
3. If worker type is unknown and no env var, skip
Set PROMETHEUS_METRICS_ENABLED=false to disable.
"""
global _server_started
with _server_lock:
if _server_started:
logger.debug(f"Metrics server already started for {worker_type}")
return None
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
if enabled in ("false", "0", "no"):
logger.info(f"Prometheus metrics server disabled for {worker_type}")
return None
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
if port_str:
try:
port = int(port_str)
except ValueError:
logger.warning(
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
"must be a numeric port. Skipping metrics server."
)
return None
elif worker_type in _DEFAULT_PORTS:
port = _DEFAULT_PORTS[worker_type]
else:
logger.info(
f"No default metrics port for worker type '{worker_type}' "
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
)
return None
try:
start_http_server(port)
_server_started = True
logger.info(
f"Prometheus metrics server started on :{port} for {worker_type}"
)
return port
except OSError as e:
logger.warning(
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
)
return None

View File

@@ -1,106 +0,0 @@
"""Prometheus metrics for OpenSearch search latency and throughput.
Tracks client-side round-trip latency, server-side execution time (from
OpenSearch's ``took`` field), total search count, and in-flight concurrency.
"""
import logging
from collections.abc import Generator
from contextlib import contextmanager
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from onyx.document_index.opensearch.constants import OpenSearchSearchType
logger = logging.getLogger(__name__)
_SEARCH_LATENCY_BUCKETS = (
0.005,
0.01,
0.025,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
25.0,
)
_client_duration = Histogram(
"onyx_opensearch_search_client_duration_seconds",
"Client-side end-to-end latency of OpenSearch search calls",
["search_type"],
buckets=_SEARCH_LATENCY_BUCKETS,
)
_server_duration = Histogram(
"onyx_opensearch_search_server_duration_seconds",
"Server-side execution time reported by OpenSearch (took field)",
["search_type"],
buckets=_SEARCH_LATENCY_BUCKETS,
)
_search_total = Counter(
"onyx_opensearch_search_total",
"Total number of search requests sent to OpenSearch",
["search_type"],
)
_searches_in_progress = Gauge(
"onyx_opensearch_searches_in_progress",
"Number of OpenSearch searches currently in-flight",
["search_type"],
)
def observe_opensearch_search(
search_type: OpenSearchSearchType,
client_duration_s: float,
server_took_ms: int | None,
) -> None:
"""Records latency and throughput metrics for a completed OpenSearch search.
Args:
search_type: The type of search.
client_duration_s: Wall-clock duration measured on the client side, in
seconds.
server_took_ms: The ``took`` value from the OpenSearch response, in
milliseconds. May be ``None`` if the response did not include it.
"""
try:
label = search_type.value
_search_total.labels(search_type=label).inc()
_client_duration.labels(search_type=label).observe(client_duration_s)
if server_took_ms is not None:
_server_duration.labels(search_type=label).observe(server_took_ms / 1000.0)
except Exception:
logger.warning("Failed to record OpenSearch search metrics.", exc_info=True)
@contextmanager
def track_opensearch_search_in_progress(
search_type: OpenSearchSearchType,
) -> Generator[None, None, None]:
"""Context manager that tracks in-flight OpenSearch searches via a Gauge."""
incremented = False
label = search_type.value
try:
_searches_in_progress.labels(search_type=label).inc()
incremented = True
except Exception:
logger.warning("Failed to increment in-progress search gauge.", exc_info=True)
try:
yield
finally:
if incremented:
try:
_searches_in_progress.labels(search_type=label).dec()
except Exception:
logger.warning(
"Failed to decrement in-progress search gauge.", exc_info=True
)

View File

@@ -41,21 +41,6 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class ModelResponseSlot(BaseModel):
"""Pairs a reserved assistant message ID with its model display name."""
message_id: int
model_name: str
class MultiModelMessageResponseIDInfo(BaseModel):
"""Sent at the start of a multi-model streaming response.
Contains the user message ID and one slot per model being run in parallel."""
user_message_id: int | None
responses: list[ModelResponseSlot]
class SourceTag(Tag):
source: DocumentSource
@@ -101,9 +86,6 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -229,8 +211,6 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
preferred_response_id: int | None = None
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -238,11 +218,6 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SetPreferredResponseRequest(BaseModel):
user_message_id: int
preferred_response_id: int
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None

View File

@@ -8,5 +8,3 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -17,7 +17,6 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -81,7 +80,6 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
)

View File

@@ -104,7 +104,5 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -736,7 +736,7 @@ if __name__ == "__main__":
llm.config.model_name, llm.config.model_provider
)
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
persona = get_default_behavior_persona(db_session)
if persona is None:
raise ValueError("No default persona found")

View File

@@ -9,7 +9,6 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -125,12 +124,7 @@ def construct_tools(
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs.
Will simply skip tools that are not allowed/available.
Callers must supply a persona with ``tools``, ``document_sets``,
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
Will simply skip tools that are not allowed/available."""
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging
@@ -149,28 +143,6 @@ def construct_tools(
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
persona_search_info = PersonaSearchInfo(
document_set_names=[ds.name for ds in persona.document_sets],
search_start_date=persona.search_start_date,
attached_document_ids=[doc.id for doc in persona.attached_documents],
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
)
return SearchTool(
tool_id=tool_id,
emitter=emitter,
user=user,
persona_search_info=persona_search_info,
llm=llm,
document_index=document_index,
user_selected_filters=config.user_selected_filters,
project_id_filter=config.project_id_filter,
persona_id_filter=config.persona_id_filter,
bypass_acl=config.bypass_acl,
slack_context=config.slack_context,
enable_slack_search=config.enable_slack_search,
)
added_search_tool = False
for db_tool_model in persona.tools:
# If allowed_tool_ids is specified, skip tools not in the allowed list
@@ -204,9 +176,22 @@ def construct_tools(
if not search_tool_config:
search_tool_config = SearchToolConfig()
tool_dict[db_tool_model.id] = [
_build_search_tool(db_tool_model.id, search_tool_config)
]
search_tool = SearchTool(
tool_id=db_tool_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[db_tool_model.id] = [search_tool]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -436,12 +421,26 @@ def construct_tools(
# Get the database tool model for SearchTool
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
# Use the passed-in config if available, otherwise create a new one
if not search_tool_config:
search_tool_config = SearchToolConfig()
tool_dict[search_tool_db_model.id] = [
_build_search_tool(search_tool_db_model.id, search_tool_config)
]
search_tool = SearchTool(
tool_id=search_tool_db_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[search_tool_db_model.id] = [search_tool]
# Always inject MemoryTool when the user has the memory tool enabled,
# bypassing persona tool associations and allowed_tool_ids filtering

View File

@@ -51,7 +51,6 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDocsResponse
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
@@ -66,6 +65,7 @@ from onyx.db.federated import (
get_federated_connector_document_set_mappings_by_document_set_names,
)
from onyx.db.federated import list_federated_connector_oauth_tokens
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
emitter: Emitter,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Pre-extracted persona search configuration
persona_search_info: PersonaSearchInfo,
# Used for filter settings
persona: Persona,
llm: LLM,
document_index: DocumentIndex,
# Respecting user selections
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
super().__init__(emitter=emitter)
self.user = user
self.persona_search_info = persona_search_info
self.persona = persona
self.llm = llm
self.document_index = document_index
self.user_selected_filters = user_selected_filters
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Case 1: Slack bot context — requires a Slack federated connector
# linked via the persona's document sets
if self.slack_context:
document_set_names = self.persona_search_info.document_set_names
document_set_names = [ds.name for ds in self.persona.document_sets]
if not document_set_names:
logger.debug(
"Skipping Slack federated search: no document sets on persona"
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
persona_id_filter=self.persona_id_filter,
document_index=self.document_index,
user=self.user,
persona_search_info=self.persona_search_info,
persona=self.persona,
acl_filters=acl_filters,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=federated_retrieval_infos,
@@ -587,12 +587,15 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
and self.user_selected_filters.source_type
else None
)
persona_document_sets = (
[ds.name for ds in self.persona.document_sets] if self.persona else None
)
federated_retrieval_infos = (
get_federated_retrieval_functions(
db_session=db_session,
user_id=self.user.id if self.user else None,
source_types=prefetch_source_types,
document_set_names=self.persona_search_info.document_set_names,
document_set_names=persona_document_sets,
)
or []
)

View File

@@ -549,7 +549,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.4
nltk==3.9.3
# via unstructured
numpy==2.4.1
# via
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.9.2
pypdf==6.9.1
# via
# onyx
# unstructured-client
@@ -861,7 +861,7 @@ regex==2025.11.3
# dateparser
# nltk
# tiktoken
requests==2.33.0
requests==2.32.5
# via
# atlassian-python-api
# braintrust

View File

@@ -410,7 +410,7 @@ release-tag==0.5.2
# via onyx
reorder-python-imports-black==3.14.0
# via onyx
requests==2.33.0
requests==2.32.5
# via
# cohere
# google-genai

View File

@@ -244,7 +244,7 @@ referencing==0.36.2
# jsonschema-specifications
regex==2025.11.3
# via tiktoken
requests==2.33.0
requests==2.32.5
# via
# cohere
# google-genai

View File

@@ -338,7 +338,7 @@ regex==2025.11.3
# via
# tiktoken
# transformers
requests==2.33.0
requests==2.32.5
# via
# cohere
# google-genai

View File

@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -1,50 +0,0 @@
from sqlalchemy import inspect
from sqlalchemy.orm import Session
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.models import Persona
from onyx.db.models import UserProject
from tests.external_dependency_unit.conftest import create_test_user
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
user = create_test_user(db_session, "eager-load")
persona = Persona(name="eager-load-test", description="test")
project = UserProject(name="eager-load-project", user_id=user.id)
db_session.add_all([persona, project])
db_session.flush()
chat_session = create_chat_session(
db_session=db_session,
description="test",
user_id=None,
persona_id=persona.id,
project_id=project.id,
)
loaded = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=None,
db_session=db_session,
eager_load_persona=True,
)
try:
tmp = inspect(loaded)
assert tmp is not None
unloaded = tmp.unloaded
assert "persona" not in unloaded
assert "project" not in unloaded
tmp = inspect(loaded.persona)
assert tmp is not None
persona_unloaded = tmp.unloaded
assert "tools" not in persona_unloaded
assert "user_files" not in persona_unloaded
assert "document_sets" not in persona_unloaded
assert "attached_documents" not in persona_unloaded
assert "hierarchy_nodes" not in persona_unloaded
finally:
db_session.rollback()

View File

@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDoc
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
@@ -139,12 +139,12 @@ def use_mock_search_pipeline(
chunk_search_request: ChunkSearchRequest,
document_index: DocumentIndex, # noqa: ARG001
user: User | None, # noqa: ARG001
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
persona: Persona | None, # noqa: ARG001
db_session: Session | None = None, # noqa: ARG001
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001
project_id_filter: int | None = None, # noqa: ARG001
persona_id_filter: int | None = None, # noqa: ARG001
project_id: int | None = None, # noqa: ARG001
persona_id: int | None = None, # noqa: ARG001
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
acl_filters: list[str] | None = None, # noqa: ARG001
embedding_model: EmbeddingModel | None = None, # noqa: ARG001

Some files were not shown because too many files have changed in this diff Show More