mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-26 10:02:42 +00:00
Compare commits
51 Commits
dane/vecto
...
bo/hook_ui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c09649ab42 | ||
|
|
4539e164ba | ||
|
|
360a1ce281 | ||
|
|
ef5628bfa7 | ||
|
|
6ffee0021e | ||
|
|
28dc84b831 | ||
|
|
230f035500 | ||
|
|
55b24d72b4 | ||
|
|
3321a84c7d | ||
|
|
54bf32a5f8 | ||
|
|
4bb6b76be6 | ||
|
|
db94562474 | ||
|
|
582d4642c1 | ||
|
|
3caaecdb0e | ||
|
|
039b69806b | ||
|
|
63971d4958 | ||
|
|
ffd897f380 | ||
|
|
4745069232 | ||
|
|
386782f188 | ||
|
|
ff009c4129 | ||
|
|
b20a5ebf69 | ||
|
|
8645adb807 | ||
|
|
2425bd4d8d | ||
|
|
333b2b19cb | ||
|
|
44895b3bd6 | ||
|
|
78c2ecf99f | ||
|
|
e3e0e04edc | ||
|
|
a19fe03bd8 | ||
|
|
415c05b5f8 | ||
|
|
352fd19f0a | ||
|
|
41ae039bfa | ||
|
|
782c734287 | ||
|
|
728cdb0715 | ||
|
|
baf6437117 | ||
|
|
f187165077 | ||
|
|
727be3d663 | ||
|
|
98c8f9884b | ||
|
|
d79a068984 | ||
|
|
ba0740d15f | ||
|
|
86b7bed90b | ||
|
|
aead6ab9a5 | ||
|
|
c9d4c186dd | ||
|
|
70aad1ec46 | ||
|
|
ca3cc16ead | ||
|
|
9ea1780ce5 | ||
|
|
f70e5e605e | ||
|
|
84b134e226 | ||
|
|
b17c63a7d6 | ||
|
|
76c41d1b0b | ||
|
|
579b86f1ce | ||
|
|
a53cf13db1 |
@@ -6,3 +6,4 @@
|
||||
|
||||
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
|
||||
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
|
||||
7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339)
|
||||
|
||||
@@ -7,6 +7,15 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "backend/**"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
- ".github/workflows/pr-external-dependency-unit-tests.yml"
|
||||
- ".github/actions/setup-python-and-install-dependencies/**"
|
||||
- ".github/actions/setup-playwright/**"
|
||||
- "deployment/docker_compose/docker-compose.yml"
|
||||
- "deployment/docker_compose/docker-compose.dev.yml"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
@@ -7,6 +7,13 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "backend/**"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
- ".github/workflows/pr-python-connector-tests.yml"
|
||||
- ".github/actions/setup-python-and-install-dependencies/**"
|
||||
- ".github/actions/setup-playwright/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -117,7 +117,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
"consoleTitle": "API Server Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -268,7 +269,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
"consoleTitle": "Celery heavy Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
@@ -355,7 +357,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
"consoleTitle": "Celery user_file_processing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
@@ -413,7 +416,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
"""group_permissions_phase1
|
||||
|
||||
Revision ID: 25a5501dc766
|
||||
Revises: b728689f45b1
|
||||
Create Date: 2026-03-23 11:41:25.557442
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "25a5501dc766"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Add account_type column to user table (nullable for now).
|
||||
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"account_type",
|
||||
sa.Enum(AccountType, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. Add is_default column to user_group table
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Create permission_grant table
|
||||
op.create_table(
|
||||
"permission_grant",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"permission",
|
||||
sa.Enum(Permission, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"grant_source",
|
||||
sa.Enum(GrantSource, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["group_id"],
|
||||
["user_group.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["granted_by"],
|
||||
["user.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Index on user__user_group(user_id) — existing composite PK
|
||||
# has user_group_id as leading column; user-filtered queries need this
|
||||
op.create_index(
|
||||
"ix_user__user_group_user_id",
|
||||
"user__user_group",
|
||||
["user_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
|
||||
op.drop_table("permission_grant")
|
||||
op.drop_column("user_group", "is_default")
|
||||
op.drop_column("user", "account_type")
|
||||
@@ -25,10 +25,13 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
# Maximum tenants to provision in a single task run.
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -85,9 +88,26 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
# just provision one tenant each time we run this ... increase if needed.
|
||||
if tenants_to_provision > 0:
|
||||
pre_provision_tenant()
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
for i in range(batch_size):
|
||||
task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
|
||||
try:
|
||||
if pre_provision_tenant():
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
@@ -101,11 +121,13 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> None:
|
||||
def pre_provision_tenant() -> bool:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
so it's ready to be assigned to a user immediately.
|
||||
|
||||
Returns True if a tenant was successfully provisioned, False otherwise.
|
||||
"""
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
@@ -118,10 +140,10 @@ def pre_provision_tenant() -> None:
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
if not lock_provision.acquire(blocking=False):
|
||||
task_logger.debug(
|
||||
"Skipping pre_provision_tenant task because it is already running"
|
||||
task_logger.warning(
|
||||
"Skipping pre_provision_tenant — could not acquire provision lock"
|
||||
)
|
||||
return
|
||||
return False
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
@@ -161,6 +183,7 @@ def pre_provision_tenant() -> None:
|
||||
db_session.add(new_tenant)
|
||||
db_session.commit()
|
||||
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
|
||||
return True
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
task_logger.error(
|
||||
@@ -184,6 +207,7 @@ def pre_provision_tenant() -> None:
|
||||
asyncio.run(rollback_tenant_provisioning(tenant_id))
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
lock_provision.release()
|
||||
|
||||
@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = (
|
||||
select(TokenRateLimit)
|
||||
.join(
|
||||
TokenRateLimit__UserGroup,
|
||||
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
|
||||
)
|
||||
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
|
||||
@@ -800,6 +800,33 @@ def update_user_group(
|
||||
return db_user_group
|
||||
|
||||
|
||||
def rename_user_group(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
new_name: str,
|
||||
) -> UserGroup:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
if db_user_group is None:
|
||||
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
db_user_group.name = new_name
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
# CC pair documents in Vespa contain the group name, so we need to
|
||||
# trigger a sync to update them with the new name.
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
|
||||
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
|
||||
@@ -56,7 +56,7 @@ def _run_single_search(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
persona_search_info=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.persona import update_persona_access
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
@@ -11,13 +12,16 @@ from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -27,6 +31,9 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -87,6 +94,32 @@ def create_user_group(
|
||||
return UserGroup.from_model(db_user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
db_session=db_session,
|
||||
user_group_id=rename_request.id,
|
||||
new_name=rename_request.name,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"User group with name '{rename_request.name}' already exists.",
|
||||
)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "not found" in msg.lower():
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, msg)
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, msg)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}")
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
@@ -161,3 +194,38 @@ def delete_user_group(
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}/agents")
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
if user_group_id not in current_group_ids:
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=current_group_ids + [user_group_id],
|
||||
)
|
||||
|
||||
for agent_id in request.removed_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=[gid for gid in current_group_ids if gid != user_group_id],
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -104,6 +104,16 @@ class AddUsersToUserGroupRequest(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
class UserGroupRename(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class SetCuratorRequest(BaseModel):
|
||||
user_id: UUID
|
||||
is_curator: bool
|
||||
|
||||
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
|
||||
@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -34,6 +42,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -48,6 +58,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docfetching")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -35,6 +43,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -49,6 +59,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docprocessing")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
|
||||
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
|
||||
# effectively a no-op in normal operation. It remains as a safety net in case
|
||||
# the pool type is ever changed to prefork. Prometheus metrics are safe in
|
||||
# thread-pool mode since all threads share the same process memory and can
|
||||
# update the same Counter/Gauge/Histogram objects directly.
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
# Set by on_worker_init so on_worker_ready knows whether to start the server.
|
||||
_prometheus_collectors_ok: bool = False
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
global _prometheus_collectors_ok
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
def _setup_prometheus_collectors(sender: Any) -> bool:
|
||||
"""Register Prometheus collectors that need Redis/DB access.
|
||||
|
||||
Passes the Celery app so the queue depth collector can obtain a fresh
|
||||
broker Redis client on each scrape (rather than holding a stale reference).
|
||||
|
||||
Returns True if registration succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
from onyx.server.metrics.indexing_pipeline_setup import (
|
||||
setup_indexing_pipeline_metrics,
|
||||
)
|
||||
|
||||
setup_indexing_pipeline_metrics(sender.app)
|
||||
logger.info("Prometheus indexing pipeline collectors registered")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to register Prometheus indexing pipeline collectors")
|
||||
return False
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
if _prometheus_collectors_ok:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
start_metrics_server("monitoring")
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping Prometheus metrics server — collector registration failed"
|
||||
)
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -68,11 +69,19 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingPayload
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -424,6 +433,28 @@ def determine_search_params(
|
||||
)
|
||||
|
||||
|
||||
def _resolve_query_processing_hook_result(
|
||||
hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed,
|
||||
message_text: str,
|
||||
) -> str:
|
||||
"""Apply the Query Processing hook result to the message text.
|
||||
|
||||
Returns the (possibly rewritten) message text, or raises OnyxError with
|
||||
QUERY_REJECTED if the hook signals rejection (query is null or empty).
|
||||
HookSkipped and HookSoftFailed are pass-throughs — the original text is
|
||||
returned unchanged.
|
||||
"""
|
||||
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
|
||||
return message_text
|
||||
if not (hook_result.query and hook_result.query.strip()):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.QUERY_REJECTED,
|
||||
hook_result.rejection_message
|
||||
or "The hook extension for query processing did not return a valid query. No rejection reason was provided.",
|
||||
)
|
||||
return hook_result.query.strip()
|
||||
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
@@ -474,16 +505,24 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
|
||||
message_text = new_msg_req.message
|
||||
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session.id)
|
||||
)
|
||||
@@ -575,6 +614,28 @@ def handle_stream_message_objects(
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
# New message — run the Query Processing hook before saving to DB.
|
||||
# Skipped on regeneration: the message already exists and was accepted previously.
|
||||
# Skip the hook for empty/whitespace-only messages — no meaningful query
|
||||
# to process, and SendMessageRequest.message has no min_length guard.
|
||||
if message_text.strip():
|
||||
hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=QueryProcessingPayload(
|
||||
query=message_text,
|
||||
# Pass None for anonymous users or authenticated users without an email
|
||||
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
|
||||
# so None is accepted and serialised as null in both cases.
|
||||
user_email=None if user.is_anonymous else user.email,
|
||||
chat_session_id=str(chat_session.id),
|
||||
).model_dump(),
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
message_text = _resolve_query_processing_hook_result(
|
||||
hook_result, message_text
|
||||
)
|
||||
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
@@ -914,6 +975,17 @@ def handle_stream_message_objects(
|
||||
state_container=state_container,
|
||||
)
|
||||
|
||||
except OnyxError as e:
|
||||
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
|
||||
log_onyx_error(e)
|
||||
yield StreamingError(
|
||||
error=e.detail,
|
||||
error_code=e.error_code.code,
|
||||
is_retryable=e.status_code >= 500,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
|
||||
@@ -787,10 +787,6 @@ MINI_CHUNK_SIZE = 150
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# The maximum number of chunks that can be held for 1 document processing batch
|
||||
# The purpose of this is to set an upper bound on memory usage
|
||||
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
|
||||
@@ -123,7 +123,7 @@ class OnyxConfluence:
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"backoff_and_retry": False,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
@@ -456,7 +456,7 @@ class OnyxConfluence:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
|
||||
@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
|
||||
raise e
|
||||
|
||||
if e.response.status_code >= 500:
|
||||
if attempt >= max_retries - 1:
|
||||
raise e
|
||||
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
logger.warning(
|
||||
f"Server error {e.response.status_code}. "
|
||||
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
|
||||
)
|
||||
return math.ceil(time.monotonic() + delay)
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
@@ -53,7 +53,7 @@ class NotionPage(BaseModel):
|
||||
id: str
|
||||
created_time: str
|
||||
last_edited_time: str
|
||||
archived: bool
|
||||
in_trash: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
@@ -63,6 +63,13 @@ class NotionPage(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class NotionDataSource(BaseModel):
|
||||
"""Represents a Notion Data Source within a database."""
|
||||
|
||||
id: str
|
||||
name: str = ""
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
@@ -107,7 +114,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
"Notion-Version": "2026-03-11",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
@@ -127,6 +134,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Maps child page IDs to their containing page ID (discovered in _read_blocks).
|
||||
# Used to resolve block_id parent types to the actual containing page.
|
||||
self._child_page_parent_map: dict[str, str] = {}
|
||||
# Maps data_source_id -> database_id (populated in _read_pages_from_database).
|
||||
# Used to resolve data_source_id parent types back to the database.
|
||||
self._data_source_to_database_map: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@@ -227,7 +237,11 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
"""Attempt to fetch a database as a page.
|
||||
|
||||
Note: As of API 2025-09-03, database objects no longer include
|
||||
`properties` (schema moved to individual data sources).
|
||||
"""
|
||||
logger.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
res = rl_requests.get(
|
||||
@@ -246,18 +260,52 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
db_data.setdefault("properties", {})
|
||||
|
||||
return NotionPage(**db_data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: str | None = None
|
||||
def _fetch_data_sources_for_database(
|
||||
self, database_id: str
|
||||
) -> list[NotionDataSource]:
|
||||
"""Fetch the list of data sources for a database."""
|
||||
logger.debug(f"Fetching data sources for database '{database_id}'")
|
||||
res = rl_requests.get(
|
||||
f"https://api.notion.com/v1/databases/{database_id}",
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return []
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
raise e
|
||||
|
||||
db_data = res.json()
|
||||
data_sources = db_data.get("data_sources", [])
|
||||
return [
|
||||
NotionDataSource(id=ds["id"], name=ds.get("name", ""))
|
||||
for ds in data_sources
|
||||
if ds.get("id")
|
||||
]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_data_source(
|
||||
self, data_source_id: str, cursor: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
"""Query a data source via POST /v1/data_sources/{id}/query."""
|
||||
logger.debug(f"Querying data source '{data_source_id}'")
|
||||
url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
|
||||
body = None if not cursor else {"start_cursor": cursor}
|
||||
res = rl_requests.post(
|
||||
block_url,
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
@@ -265,25 +313,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
json_data = res.json()
|
||||
code = json_data.get("code")
|
||||
# Sep 3 2025 backend changed the error message for this case
|
||||
# TODO: it is also now possible for there to be multiple data sources per database; at present we
|
||||
# just don't handle that. We will need to upgrade the API to the current version + query the
|
||||
# new data sources endpoint to handle that case correctly.
|
||||
if code == "object_not_found" or (
|
||||
code == "validation_error"
|
||||
and "does not contain any data sources" in json_data.get("message", "")
|
||||
):
|
||||
# this happens when a database is not shared with the integration
|
||||
# in this case, we should just ignore the database
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"Unable to access data source with ID '{data_source_id}'. "
|
||||
f"This is likely due to it not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
logger.exception(f"Error querying data source - {res.json()}")
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
@@ -348,8 +385,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Fallback to workspace if we don't know the parent
|
||||
return self.workspace_id
|
||||
elif parent_type == "data_source_id":
|
||||
# Newer Notion API may use data_source_id for databases
|
||||
return parent.get("database_id") or parent.get("data_source_id")
|
||||
ds_id = parent.get("data_source_id")
|
||||
if ds_id:
|
||||
return self._data_source_to_database_map.get(ds_id, self.workspace_id)
|
||||
elif parent_type in ["page_id", "database_id"]:
|
||||
return parent.get(parent_type)
|
||||
|
||||
@@ -497,18 +535,32 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if db_node:
|
||||
hierarchy_nodes.append(db_node)
|
||||
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
# Discover all data sources under this database, then query each one.
|
||||
# Even legacy single-source databases have one entry in the array.
|
||||
data_sources = self._fetch_data_sources_for_database(database_id)
|
||||
if not data_sources:
|
||||
logger.warning(
|
||||
f"Database '{database_id}' returned zero data sources — "
|
||||
f"no pages will be indexed from this database."
|
||||
)
|
||||
for ds in data_sources:
|
||||
self._data_source_to_database_map[ds.id] = database_id
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_data_source(ds.id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(
|
||||
NotionBlock(id=obj_id, text=text, prefix="\n")
|
||||
)
|
||||
|
||||
if not self.recursive_index_enabled:
|
||||
continue
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
@@ -518,7 +570,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
# Get nested database name from properties if available
|
||||
nested_db_title = result.get("title", [])
|
||||
nested_db_name = None
|
||||
if nested_db_title and len(nested_db_title) > 0:
|
||||
@@ -533,10 +584,10 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
result_pages.extend(nested_output.child_page_ids)
|
||||
hierarchy_nodes.extend(nested_output.hierarchy_nodes)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return BlockReadOutput(
|
||||
blocks=result_blocks,
|
||||
@@ -807,36 +858,55 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def _yield_database_hierarchy_nodes(
|
||||
self,
|
||||
) -> Generator[HierarchyNode | Document, None, None]:
|
||||
"""Search for all databases and yield hierarchy nodes for each.
|
||||
"""Search for all data sources and yield hierarchy nodes for their parent databases.
|
||||
|
||||
This must be called BEFORE page indexing so that database hierarchy nodes
|
||||
exist when pages inside databases reference them as parents.
|
||||
|
||||
With the new API, search returns data source objects instead of databases.
|
||||
Multiple data sources can share the same parent database, so we use
|
||||
database_id as the hierarchy node key and deduplicate via
|
||||
_maybe_yield_hierarchy_node.
|
||||
"""
|
||||
query_dict: dict[str, Any] = {
|
||||
"filter": {"property": "object", "value": "database"},
|
||||
"filter": {"property": "object", "value": "data_source"},
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
}
|
||||
pages_seen = 0
|
||||
while pages_seen < _MAX_PAGES:
|
||||
db_res = self._search_notion(query_dict)
|
||||
for db in db_res.results:
|
||||
db_id = db["id"]
|
||||
# Extract title from the title array
|
||||
title_arr = db.get("title", [])
|
||||
db_name = None
|
||||
if title_arr:
|
||||
db_name = " ".join(
|
||||
t.get("plain_text", "") for t in title_arr
|
||||
).strip()
|
||||
if not db_name:
|
||||
for ds in db_res.results:
|
||||
# Extract the parent database_id from the data source's parent
|
||||
ds_parent = ds.get("parent", {})
|
||||
db_id = ds_parent.get("database_id")
|
||||
if not db_id:
|
||||
continue
|
||||
|
||||
# Populate the mapping so _get_parent_raw_id can resolve later
|
||||
ds_id = ds.get("id")
|
||||
if not ds_id:
|
||||
continue
|
||||
self._data_source_to_database_map[ds_id] = db_id
|
||||
|
||||
# Fetch the database to get its actual name and parent
|
||||
try:
|
||||
db_page = self._fetch_database_as_page(db_id)
|
||||
db_name = db_page.database_name or f"Database {db_id}"
|
||||
parent_raw_id = self._get_parent_raw_id(db_page.parent)
|
||||
db_url = (
|
||||
db_page.url or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(
|
||||
f"Could not fetch database '{db_id}', "
|
||||
f"defaulting to workspace root. Error: {e}"
|
||||
)
|
||||
db_name = f"Database {db_id}"
|
||||
parent_raw_id = self.workspace_id
|
||||
db_url = f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# Get parent using existing helper
|
||||
parent_raw_id = self._get_parent_raw_id(db.get("parent"))
|
||||
|
||||
# Notion URLs omit dashes from UUIDs
|
||||
db_url = db.get("url") or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# _maybe_yield_hierarchy_node deduplicates by raw_node_id,
|
||||
# so multiple data sources under one database produce one node.
|
||||
node = self._maybe_yield_hierarchy_node(
|
||||
raw_node_id=db_id,
|
||||
raw_parent_id=parent_raw_id or self.workspace_id,
|
||||
|
||||
@@ -401,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class PersonaSearchInfo(BaseModel):
|
||||
"""Snapshot of persona data needed by the search pipeline.
|
||||
|
||||
Extracted from the ORM Persona before the DB session is released so that
|
||||
SearchTool and search_pipeline never lazy-load relationships post-commit.
|
||||
"""
|
||||
|
||||
document_set_names: list[str]
|
||||
search_start_date: datetime | None
|
||||
attached_document_ids: list[str]
|
||||
hierarchy_node_ids: list[int]
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.retrieval.search_runner import search_chunks
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
@@ -247,8 +247,8 @@ def search_pipeline(
|
||||
document_index: DocumentIndex,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
# Pre-extracted persona search configuration (None when no persona)
|
||||
persona_search_info: PersonaSearchInfo | None,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -263,24 +263,18 @@ def search_pipeline(
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
else None
|
||||
persona_search_info.document_set_names if persona_search_info else None
|
||||
)
|
||||
persona_time_cutoff: datetime | None = (
|
||||
persona.search_start_date if persona else None
|
||||
persona_search_info.search_start_date if persona_search_info else None
|
||||
)
|
||||
|
||||
# Extract assistant knowledge filters from persona
|
||||
attached_document_ids: list[str] | None = (
|
||||
[doc.id for doc in persona.attached_documents]
|
||||
if persona and persona.attached_documents
|
||||
persona_search_info.attached_document_ids or None
|
||||
if persona_search_info
|
||||
else None
|
||||
)
|
||||
hierarchy_node_ids: list[int] | None = (
|
||||
[node.id for node in persona.hierarchy_nodes]
|
||||
if persona and persona.hierarchy_nodes
|
||||
else None
|
||||
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
|
||||
)
|
||||
|
||||
filters = _build_index_filters(
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy import Row
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +29,7 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -53,9 +55,22 @@ def get_chat_session_by_id(
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_shared: bool = False,
|
||||
eager_load_persona: bool = False,
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
|
||||
if eager_load_persona:
|
||||
stmt = stmt.options(
|
||||
joinedload(ChatSession.persona).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
),
|
||||
joinedload(ChatSession.project),
|
||||
)
|
||||
|
||||
if is_shared:
|
||||
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
|
||||
else:
|
||||
|
||||
@@ -750,3 +750,31 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -1,4 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""
|
||||
What kind of account this is — determines whether the user
|
||||
enters the group-based permission system.
|
||||
|
||||
STANDARD + SERVICE_ACCOUNT → participate in group system
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
|
||||
class Permission(str, PyEnum):
|
||||
"""
|
||||
Permission tokens for group-based authorization.
|
||||
19 tokens total. full_admin_panel_access is an override —
|
||||
if present, any permission check passes.
|
||||
"""
|
||||
|
||||
# Basic (auto-granted to every new group)
|
||||
BASIC_ACCESS = "basic"
|
||||
|
||||
# Read tokens — implied only, never granted directly
|
||||
READ_CONNECTORS = "read:connectors"
|
||||
READ_DOCUMENT_SETS = "read:document_sets"
|
||||
READ_AGENTS = "read:agents"
|
||||
READ_USERS = "read:users"
|
||||
|
||||
# Add / Manage pairs
|
||||
ADD_AGENTS = "add:agents"
|
||||
MANAGE_AGENTS = "manage:agents"
|
||||
MANAGE_DOCUMENT_SETS = "manage:document_sets"
|
||||
ADD_CONNECTORS = "add:connectors"
|
||||
MANAGE_CONNECTORS = "manage:connectors"
|
||||
MANAGE_LLMS = "manage:llms"
|
||||
|
||||
# Toggle tokens
|
||||
READ_AGENT_ANALYTICS = "read:agent_analytics"
|
||||
MANAGE_ACTIONS = "manage:actions"
|
||||
READ_QUERY_HISTORY = "read:query_history"
|
||||
MANAGE_USER_GROUPS = "manage:user_groups"
|
||||
CREATE_USER_API_KEYS = "create:user_api_keys"
|
||||
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
|
||||
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
|
||||
|
||||
# Override — any permission check passes
|
||||
FULL_ADMIN_PANEL_ACCESS = "admin"
|
||||
|
||||
# Permissions that are implied by other grants and must never be stored
|
||||
# directly in the permission_grant table.
|
||||
IMPLIED: ClassVar[frozenset[Permission]]
|
||||
|
||||
|
||||
Permission.IMPLIED = frozenset(
|
||||
{
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -75,6 +75,7 @@ def create_hook__no_commit(
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
is_reachable: bool | None = None,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
@@ -100,6 +101,7 @@ def create_hook__no_commit(
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
is_reachable=is_reachable,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
|
||||
@@ -2,6 +2,8 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
@@ -972,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
|
||||
@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import (
|
||||
ANONYMOUS_USER_UUID,
|
||||
@@ -78,6 +79,8 @@ from onyx.db.enums import (
|
||||
MCPAuthenticationPerformer,
|
||||
MCPTransport,
|
||||
MCPServerStatus,
|
||||
Permission,
|
||||
GrantSource,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
Preferences probably should be in a separate table at some point, but for now
|
||||
@@ -3971,6 +3977,8 @@ class SamlAccount(Base):
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
|
||||
|
||||
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
@@ -3981,6 +3989,48 @@ class User__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class PermissionGrant(Base):
|
||||
__tablename__ = "permission_grant"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
)
|
||||
granted_by: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
granted_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
group: Mapped["UserGroup"] = relationship(
|
||||
"UserGroup", back_populates="permission_grants"
|
||||
)
|
||||
|
||||
@validates("permission")
|
||||
def _validate_permission(self, _key: str, value: Permission) -> Permission:
|
||||
if value in Permission.IMPLIED:
|
||||
raise ValueError(
|
||||
f"{value!r} is an implied permission and cannot be granted directly"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
@@ -4075,6 +4125,8 @@ class UserGroup(Base):
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -4118,6 +4170,9 @@ class UserGroup(Base):
|
||||
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
|
||||
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
|
||||
)
|
||||
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
|
||||
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Token Rate Limiting
|
||||
|
||||
@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_default_behavior_persona(db_session: Session) -> Persona | None:
|
||||
def get_default_behavior_persona(
|
||||
db_session: Session,
|
||||
eager_load_for_tools: bool = False,
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
|
||||
if eager_load_for_tools:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
)
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -351,7 +350,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -647,10 +646,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -673,34 +672,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
]
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -709,43 +703,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
|
||||
_flush_chunks(current_chunks)
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -10,7 +8,6 @@ import httpx
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
|
||||
from onyx.configs.app_configs import RERANK_COUNT
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
@@ -321,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -341,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -421,16 +409,8 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(
|
||||
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
# Insert new Vespa documents.
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self._index_name,
|
||||
@@ -439,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
QUERY_REJECTED = ("QUERY_REJECTED", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
|
||||
@@ -5,6 +5,7 @@ Usage (Celery tasks and FastAPI handlers):
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
@@ -14,7 +15,7 @@ Usage (Celery tasks and FastAPI handlers):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
@@ -53,9 +54,11 @@ The executor uses three sessions:
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -81,6 +84,9 @@ class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -268,22 +274,21 @@ def _persist_result(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
@@ -300,13 +305,36 @@ def execute_hook(
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
@@ -323,8 +351,41 @@ def execute_hook(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return outcome.response_payload
|
||||
return validated_model
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
|
||||
@@ -91,6 +91,8 @@ class HookResponse(BaseModel):
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
|
||||
api_key_masked: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
|
||||
@@ -51,13 +51,12 @@ class HookPointSpec:
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
"""Enforce that every subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
Raises TypeError at import time if any required attribute is missing or if
|
||||
payload_model / response_model are not Pydantic BaseModel subclasses.
|
||||
input_schema and output_schema are derived automatically from the models.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
|
||||
@@ -26,6 +26,8 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class QueryProcessingResponse(BaseModel):
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
"Null, empty string, whitespace-only, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
@@ -65,6 +65,8 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -21,8 +19,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
@@ -88,21 +85,14 @@ class DocumentIndexingBatchAdapter:
|
||||
) as transaction:
|
||||
yield transaction
|
||||
|
||||
def prepare_enrichment(
|
||||
def build_metadata_aware_chunks(
|
||||
self,
|
||||
context: DocumentBatchPrepareContext,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> DocumentChunkEnricher:
|
||||
"""Do all DB lookups once and return a per-chunk enricher."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -112,30 +102,67 @@ class DocumentIndexingBatchAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
return DocumentChunkEnricher(
|
||||
doc_id_to_access_info=get_access_for_documents(
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_access_info = get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
doc_id_to_document_set = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
),
|
||||
doc_id_to_document_set={
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
},
|
||||
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
),
|
||||
id_to_boost_map=context.id_to_boost_map,
|
||||
doc_id_to_previous_chunk_cnt={
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
},
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks_with_embeddings:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
# Get ancestor hierarchy node IDs for each document
|
||||
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
)
|
||||
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
|
||||
user_file_id_to_raw_text={},
|
||||
user_file_id_to_token_count={},
|
||||
)
|
||||
|
||||
def _get_ancestor_ids_for_documents(
|
||||
@@ -176,7 +203,7 @@ class DocumentIndexingBatchAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
) -> None:
|
||||
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
@@ -200,7 +227,7 @@ class DocumentIndexingBatchAdapter:
|
||||
|
||||
update_docs_chunk_count__no_commit(
|
||||
document_ids=updatable_ids,
|
||||
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
|
||||
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@@ -222,52 +249,3 @@ class DocumentIndexingBatchAdapter:
|
||||
)
|
||||
|
||||
self.db_session.commit()
|
||||
|
||||
|
||||
class DocumentChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_id_to_access_info: dict[str, DocumentAccess],
|
||||
doc_id_to_document_set: dict[str, list[str]],
|
||||
doc_id_to_ancestor_ids: dict[str, list[int]],
|
||||
id_to_boost_map: dict[str, int],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._doc_id_to_access_info = doc_id_to_access_info
|
||||
self._doc_id_to_document_set = doc_id_to_document_set
|
||||
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
|
||||
self._id_to_boost_map = id_to_boost_map
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._doc_id_to_access_info.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(
|
||||
self._doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
self._id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in self._id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
@@ -27,8 +24,7 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
@@ -105,20 +101,13 @@ class UserFileIndexingAdapter:
|
||||
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
|
||||
)
|
||||
|
||||
def prepare_enrichment(
|
||||
def build_metadata_aware_chunks(
|
||||
self,
|
||||
context: DocumentBatchPrepareContext,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> UserFileChunkEnricher:
|
||||
"""Do all DB lookups and pre-compute file metadata from chunks."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
|
||||
content_by_file: dict[str, list[str]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
content_by_file[chunk.source_document.id].append(chunk.content)
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -128,6 +117,7 @@ class UserFileIndexingAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -148,6 +138,17 @@ class UserFileIndexingAdapter:
|
||||
)
|
||||
}
|
||||
|
||||
user_file_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
user_file_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
)
|
||||
for user_file_id in updatable_ids
|
||||
}
|
||||
|
||||
# Initialize tokenizer used for token count calculation
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
@@ -162,9 +163,15 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_raw_text: dict[str, str] = {}
|
||||
user_file_id_to_token_count: dict[str, int | None] = {}
|
||||
for user_file_id in updatable_ids:
|
||||
contents = content_by_file.get(user_file_id)
|
||||
if contents:
|
||||
combined_content = " ".join(contents)
|
||||
user_file_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
if user_file_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
@@ -174,16 +181,28 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_raw_text[str(user_file_id)] = ""
|
||||
user_file_id_to_token_count[str(user_file_id)] = None
|
||||
|
||||
return UserFileChunkEnricher(
|
||||
user_file_id_to_access=user_file_id_to_access,
|
||||
user_file_id_to_project_ids=user_file_id_to_project_ids,
|
||||
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(),
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
|
||||
user_file_id_to_raw_text=user_file_id_to_raw_text,
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def _notify_assistant_owners_if_files_ready(
|
||||
@@ -227,9 +246,8 @@ class UserFileIndexingAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
|
||||
filtered_documents: list[Document], # noqa: ARG002
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
) -> None:
|
||||
assert isinstance(enrichment, UserFileChunkEnricher)
|
||||
user_file_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
user_files = (
|
||||
@@ -245,10 +263,8 @@ class UserFileIndexingAdapter:
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
|
||||
str(user_file.id), 0
|
||||
)
|
||||
user_file.token_count = enrichment.user_file_id_to_token_count[
|
||||
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
|
||||
user_file.token_count = result.user_file_id_to_token_count[
|
||||
str(user_file.id)
|
||||
]
|
||||
|
||||
@@ -260,54 +276,8 @@ class UserFileIndexingAdapter:
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
# NOTE: this creates its own session to avoid committing the overall
|
||||
# transaction.
|
||||
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
|
||||
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
|
||||
store_user_file_plaintext(
|
||||
user_file_id=UUID(user_file_id),
|
||||
plaintext_content=raw_text,
|
||||
)
|
||||
|
||||
|
||||
class UserFileChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_file_id_to_access: dict[str, DocumentAccess],
|
||||
user_file_id_to_project_ids: dict[str, list[int]],
|
||||
user_file_id_to_persona_ids: dict[str, list[int]],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
user_file_id_to_raw_text: dict[str, str],
|
||||
user_file_id_to_token_count: dict[str, int | None],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._user_file_id_to_access = user_file_id_to_access
|
||||
self._user_file_id_to_project_ids = user_file_id_to_project_ids
|
||||
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
self.user_file_id_to_raw_text = user_file_id_to_raw_text
|
||||
self.user_file_id_to_token_count = user_file_id_to_token_count
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._user_file_id_to_access.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(),
|
||||
user_project=self._user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=self._user_file_id_to_persona_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -49,7 +47,6 @@ from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
@@ -94,15 +91,6 @@ class IndexingPipelineResult(BaseModel):
|
||||
|
||||
failures: list[ConnectorFailure]
|
||||
|
||||
@classmethod
|
||||
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
|
||||
return cls(
|
||||
new_docs=0,
|
||||
total_docs=total_docs,
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
@@ -684,7 +672,12 @@ def index_doc_batch(
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
context = adapter.prepare(filtered_documents, ignore_time_skip)
|
||||
if not context:
|
||||
return IndexingPipelineResult.empty(len(filtered_documents))
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
|
||||
# Convert documents to IndexingDocument objects with processed section
|
||||
# logger.debug("Processing image sections")
|
||||
@@ -755,29 +748,19 @@ def index_doc_batch(
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
# to resolve this, an update of the last modified field at the end of this loop
|
||||
# always triggers a final metadata sync via the celery queue
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=chunks_with_embeddings,
|
||||
chunk_content_scores=chunk_content_scores,
|
||||
tenant_id=tenant_id,
|
||||
chunks=cast(list[DocAwareChunk], chunks_with_embeddings),
|
||||
context=context,
|
||||
)
|
||||
|
||||
metadata_aware_chunks = [
|
||||
enricher.enrich_chunk(chunk, score)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
short_descriptor_list = [
|
||||
chunk.to_short_descriptor() for chunk in metadata_aware_chunks
|
||||
]
|
||||
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
|
||||
short_descriptor_log = str(short_descriptor_list)[:1024]
|
||||
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
|
||||
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
|
||||
|
||||
def chunk_iterable_creator() -> Iterable[DocMetadataAwareIndexChunk]:
|
||||
return metadata_aware_chunks
|
||||
|
||||
for document_index in document_indices:
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
@@ -787,10 +770,10 @@ def index_doc_batch(
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
make_chunks=chunk_iterable_creator,
|
||||
chunks=result.chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
|
||||
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
@@ -819,7 +802,7 @@ def index_doc_batch(
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"This occurred for document index {document_index.__class__.__name__}"
|
||||
f"This occured for document index {document_index.__class__.__name__}"
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
@@ -832,7 +815,7 @@ def index_doc_batch(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
enrichment=enricher,
|
||||
result=result,
|
||||
)
|
||||
|
||||
assert primary_doc_idx_insertion_records is not None
|
||||
|
||||
@@ -235,16 +235,12 @@ class UpdatableChunkData(BaseModel):
|
||||
boost_score: float
|
||||
|
||||
|
||||
class ChunkEnrichmentContext(Protocol):
|
||||
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
|
||||
and provides per-chunk enrichment."""
|
||||
|
||||
class BuildMetadataAwareChunksResult(BaseModel):
|
||||
chunks: list[DocMetadataAwareIndexChunk]
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk: ...
|
||||
user_file_id_to_raw_text: dict[str, str]
|
||||
user_file_id_to_token_count: dict[str, int | None]
|
||||
|
||||
|
||||
class IndexingBatchAdapter(Protocol):
|
||||
@@ -258,24 +254,18 @@ class IndexingBatchAdapter(Protocol):
|
||||
) -> Generator[TransactionalContext, None, None]:
|
||||
"""Provide a transaction/row-lock context for critical updates."""
|
||||
|
||||
def prepare_enrichment(
|
||||
def build_metadata_aware_chunks(
|
||||
self,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> ChunkEnrichmentContext:
|
||||
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
|
||||
|
||||
Precondition: ``chunks`` have already been through the embedding step
|
||||
(i.e. they are ``IndexChunk`` instances with populated embeddings,
|
||||
passed here as the base ``DocAwareChunk`` type).
|
||||
"""
|
||||
...
|
||||
context: "DocumentBatchPrepareContext",
|
||||
) -> BuildMetadataAwareChunksResult: ...
|
||||
|
||||
def post_index(
|
||||
self,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
) -> None: ...
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -31,22 +28,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
|
||||
|
||||
def write_chunks_to_vector_db_with_backoff(
|
||||
document_index: DocumentIndex,
|
||||
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
|
||||
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
|
||||
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
|
||||
vector DB interface assumes that all chunks for a single document are present. The
|
||||
chunks must also be in contiguous batches
|
||||
vector DB interface assumes that all chunks for a single document are present.
|
||||
"""
|
||||
|
||||
# first try to write the chunks to the vector db
|
||||
try:
|
||||
return (
|
||||
list(
|
||||
document_index.index(
|
||||
chunks=make_chunks(),
|
||||
chunks=chunks,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
),
|
||||
@@ -63,16 +60,14 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
# wait a couple seconds just to give the vector db a chance to recover
|
||||
time.sleep(2)
|
||||
|
||||
# try writing each doc one by one
|
||||
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_for_docs[chunk.source_document.id].append(chunk)
|
||||
|
||||
insertion_records: list[DocumentInsertionRecord] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
|
||||
def key(chunk: DocMetadataAwareIndexChunk) -> str:
|
||||
return chunk.source_document.id
|
||||
|
||||
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
|
||||
first_chunk = next(chunks_for_doc)
|
||||
chunks_for_doc = chain([first_chunk], chunks_for_doc)
|
||||
|
||||
for doc_id, chunks_for_doc in chunks_for_docs.items():
|
||||
try:
|
||||
insertion_records.extend(
|
||||
document_index.index(
|
||||
@@ -92,7 +87,9 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=first_chunk.get_link(),
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
|
||||
@@ -62,6 +62,9 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key_masked=(
|
||||
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
|
||||
),
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
@@ -220,8 +223,8 @@ def create_hook(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created active.
|
||||
"""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
@@ -240,9 +243,10 @@ def create_hook(
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.celery_task_metrics import (
|
||||
on_celery_task_prerun,
|
||||
on_celery_task_postrun,
|
||||
on_celery_task_retry,
|
||||
on_celery_task_revoked,
|
||||
on_celery_task_rejected,
|
||||
)
|
||||
# Call from the worker's existing signal handlers
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
TASK_STARTED = Counter(
|
||||
"onyx_celery_task_started_total",
|
||||
"Total Celery tasks started",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_COMPLETED = Counter(
|
||||
"onyx_celery_task_completed_total",
|
||||
"Total Celery tasks completed",
|
||||
["task_name", "queue", "outcome"],
|
||||
)
|
||||
|
||||
TASK_DURATION = Histogram(
|
||||
"onyx_celery_task_duration_seconds",
|
||||
"Celery task execution duration in seconds",
|
||||
["task_name", "queue"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
TASKS_ACTIVE = Gauge(
|
||||
"onyx_celery_tasks_active",
|
||||
"Currently executing Celery tasks",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_RETRIED = Counter(
|
||||
"onyx_celery_task_retried_total",
|
||||
"Total Celery tasks retried",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_REVOKED = Counter(
|
||||
"onyx_celery_task_revoked_total",
|
||||
"Total Celery tasks revoked (cancelled)",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_REJECTED = Counter(
|
||||
"onyx_celery_task_rejected_total",
|
||||
"Total Celery tasks rejected by worker",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
# Lock protecting _task_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_task_start_times_lock = threading.Lock()
|
||||
|
||||
# Entries older than this are evicted on each prerun to prevent unbounded
|
||||
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
|
||||
_MAX_START_TIME_AGE_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _task_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, (start, _labels) in _task_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
entry = _task_start_times.pop(tid, None)
|
||||
if entry is not None:
|
||||
_labels = entry[1]
|
||||
# Decrement active gauge for evicted tasks — these tasks were
|
||||
# started but never completed (killed, OOM, etc.).
|
||||
active_gauge = TASKS_ACTIVE.labels(**_labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
|
||||
def _get_task_labels(task: Task) -> dict[str, str]:
|
||||
"""Extract task_name and queue labels from a Celery Task instance."""
|
||||
task_name = task.name or "unknown"
|
||||
queue = "unknown"
|
||||
try:
|
||||
delivery_info = task.request.delivery_info
|
||||
if delivery_info:
|
||||
queue = delivery_info.get("routing_key") or "unknown"
|
||||
except AttributeError:
|
||||
pass
|
||||
return {"task_name": task_name, "queue": queue}
|
||||
|
||||
|
||||
def on_celery_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task start. Call from the worker's task_prerun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_STARTED.labels(**labels).inc()
|
||||
TASKS_ACTIVE.labels(**labels).inc()
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record task completion. Call from the worker's task_postrun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
TASK_COMPLETED.labels(**labels, outcome=outcome).inc()
|
||||
|
||||
# Guard against going below 0 if postrun fires without a matching
|
||||
# prerun (e.g. after a worker restart or stale entry eviction).
|
||||
active_gauge = TASKS_ACTIVE.labels(**labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
with _task_start_times_lock:
|
||||
entry = _task_start_times.pop(task_id, None)
|
||||
if entry is not None:
|
||||
start_time, _stored_labels = entry
|
||||
TASK_DURATION.labels(**labels).observe(time.monotonic() - start_time)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task postrun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_retry(
|
||||
_task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task retry. Call from the worker's task_retry signal handler."""
|
||||
if task is None:
|
||||
return
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_RETRIED.labels(**labels).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task retry metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_revoked(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task revocation. The revoked signal doesn't provide a Task
|
||||
instance, only the task name via sender."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REVOKED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task revoked metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_rejected(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task rejection."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REJECTED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task rejected metrics", exc_info=True)
|
||||
528
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
528
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default cache TTL in seconds. Scrapes hitting within this window return
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
|
||||
OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
|
||||
OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
|
||||
OnyxCeleryQueues.MONITORING: "monitoring",
|
||||
OnyxCeleryQueues.SANDBOX: "sandbox",
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
|
||||
}
|
||||
|
||||
# Queues where prefetched (unacked) task counts are meaningful
|
||||
_UNACKED_QUEUES: list[str] = [
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
OnyxCeleryQueues.DOCPROCESSING,
|
||||
]
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
if (
|
||||
now - self._last_collect_time < self._cache_ttl
|
||||
and self._cached_result is not None
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
raise NotImplementedError
|
||||
|
||||
def describe(self) -> list[GaugeMetricFamily]:
|
||||
return []
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
"Number of tasks waiting in Celery queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
unacked = GaugeMetricFamily(
|
||||
"onyx_queue_unacked",
|
||||
"Number of prefetched (unacked) tasks for queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
queue_age = GaugeMetricFamily(
|
||||
"onyx_queue_oldest_task_age_seconds",
|
||||
"Age of the oldest task in the queue (seconds since enqueue)",
|
||||
labels=["queue"],
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
for queue_name, label in _QUEUE_LABEL_MAP.items():
|
||||
length = celery_get_queue_length(queue_name, redis_client)
|
||||
depth.add_metric([label], length)
|
||||
|
||||
# Peek at the oldest message to get its age
|
||||
if length > 0:
|
||||
age = self._get_oldest_message_age(redis_client, queue_name, now)
|
||||
if age is not None:
|
||||
queue_age.add_metric([label], age)
|
||||
|
||||
for queue_name in _UNACKED_QUEUES:
|
||||
label = _QUEUE_LABEL_MAP[queue_name]
|
||||
task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
|
||||
unacked.add_metric([label], len(task_ids))
|
||||
|
||||
return [depth, unacked, queue_age]
|
||||
|
||||
@staticmethod
|
||||
def _get_oldest_message_age(
|
||||
redis_client: Redis, queue_name: str, now: float
|
||||
) -> float | None:
|
||||
"""Peek at the oldest (tail) message in a Redis list queue
|
||||
and extract its timestamp to compute age.
|
||||
|
||||
Note: If the Celery message contains neither ``properties.timestamp``
|
||||
nor ``headers.timestamp``, no age metric is emitted for this queue.
|
||||
This can happen with custom task producers or non-standard Celery
|
||||
protocol versions. The metric will simply be absent rather than
|
||||
inaccurate, which is the safest behavior for alerting.
|
||||
"""
|
||||
try:
|
||||
raw: bytes | str | None = redis_client.lindex(queue_name, -1) # type: ignore[assignment]
|
||||
if raw is None:
|
||||
return None
|
||||
msg = json.loads(raw)
|
||||
# Check for ETA tasks first — they are intentionally delayed,
|
||||
# so reporting their queue age would be misleading.
|
||||
headers = msg.get("headers", {})
|
||||
if headers.get("eta") is not None:
|
||||
return None
|
||||
# Celery v2 protocol: timestamp in properties
|
||||
props = msg.get("properties", {})
|
||||
ts = props.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
# Fallback: some Celery configurations place the timestamp in
|
||||
# headers instead of properties.
|
||||
ts = headers.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
|
||||
"""Queries Postgres for connector health state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_health_for_metrics,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
|
||||
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
staleness_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"Seconds since last successful index for this connector",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
error_state_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
by_status_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_by_status",
|
||||
"Number of connectors grouped by status",
|
||||
labels=["tenant_id", "status"],
|
||||
)
|
||||
error_total_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_in_error_total",
|
||||
"Total number of connectors in repeated error state",
|
||||
labels=["tenant_id"],
|
||||
)
|
||||
per_connector_labels = [
|
||||
"tenant_id",
|
||||
"source",
|
||||
"cc_pair_id",
|
||||
"connector_name",
|
||||
]
|
||||
docs_success_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_docs_indexed",
|
||||
"Total new documents indexed (90-day rolling sum) per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
docs_error_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_error_count",
|
||||
"Total number of failed index attempts per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
pairs = get_connector_health_for_metrics(session)
|
||||
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
|
||||
docs_by_cc = get_docs_indexed_by_cc_pair(session)
|
||||
|
||||
status_counts: dict[str, int] = {}
|
||||
error_count = 0
|
||||
|
||||
for (
|
||||
cc_id,
|
||||
status,
|
||||
in_error,
|
||||
last_success,
|
||||
cc_name,
|
||||
source,
|
||||
) in pairs:
|
||||
cc_id_str = str(cc_id)
|
||||
source_val = source.value
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
label_vals = [tid, source_val, cc_id_str, name_val]
|
||||
|
||||
if last_success is not None:
|
||||
# Both `now` and `last_success` are timezone-aware
|
||||
# (the DB column uses DateTime(timezone=True)),
|
||||
# so subtraction is safe.
|
||||
age = (now - last_success).total_seconds()
|
||||
staleness_gauge.add_metric(label_vals, age)
|
||||
|
||||
error_state_gauge.add_metric(
|
||||
label_vals,
|
||||
1.0 if in_error else 0.0,
|
||||
)
|
||||
if in_error:
|
||||
error_count += 1
|
||||
|
||||
docs_success_gauge.add_metric(
|
||||
label_vals,
|
||||
docs_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
docs_error_gauge.add_metric(
|
||||
label_vals,
|
||||
error_counts_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
status_val = status.value
|
||||
status_counts[status_val] = status_counts.get(status_val, 0) + 1
|
||||
|
||||
for status_val, count in status_counts.items():
|
||||
by_status_gauge.add_metric([tid, status_val], count)
|
||||
|
||||
error_total_gauge.add_metric([tid], error_count)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [
|
||||
staleness_gauge,
|
||||
error_state_gauge,
|
||||
by_status_gauge,
|
||||
error_total_gauge,
|
||||
docs_success_gauge,
|
||||
docs_error_gauge,
|
||||
]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
"Redis used memory in bytes",
|
||||
)
|
||||
memory_peak = GaugeMetricFamily(
|
||||
"onyx_redis_memory_peak_bytes",
|
||||
"Redis peak used memory in bytes",
|
||||
)
|
||||
memory_frag = GaugeMetricFamily(
|
||||
"onyx_redis_memory_fragmentation_ratio",
|
||||
"Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
|
||||
)
|
||||
connected_clients = GaugeMetricFamily(
|
||||
"onyx_redis_connected_clients",
|
||||
"Number of connected Redis clients",
|
||||
)
|
||||
|
||||
try:
|
||||
mem_info: dict = redis_client.info("memory") # type: ignore[assignment]
|
||||
memory_used.add_metric([], mem_info.get("used_memory", 0))
|
||||
memory_peak.add_metric([], mem_info.get("used_memory_peak", 0))
|
||||
frag = mem_info.get("mem_fragmentation_ratio")
|
||||
if frag is not None:
|
||||
memory_frag.add_metric([], frag)
|
||||
|
||||
client_info: dict = redis_client.info("clients") # type: ignore[assignment]
|
||||
connected_clients.add_metric([], client_info.get("connected_clients", 0))
|
||||
except Exception:
|
||||
logger.debug("Failed to collect Redis health metrics", exc_info=True)
|
||||
|
||||
return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker count and process count via inspect ping.
|
||||
|
||||
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
|
||||
command that takes a couple seconds to complete.
|
||||
|
||||
Maintains a set of known worker short-names so that when a worker
|
||||
stops responding, we emit ``up=0`` instead of silently dropping the
|
||||
metric (which would make ``absent()``-style alerts impossible).
|
||||
"""
|
||||
|
||||
# Remove a worker from _known_workers after this many consecutive
|
||||
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
|
||||
_MAX_CONSECUTIVE_MISSES = 10
|
||||
|
||||
def __init__(self, cache_ttl: float = 60.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
# worker short-name → consecutive miss count.
|
||||
# Workers start at 0 and reset to 0 each time they respond.
|
||||
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
|
||||
self._known_workers: dict[str, int] = {}
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app instance for inspect commands."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
active_workers = GaugeMetricFamily(
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers responding to ping",
|
||||
)
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
"Whether a specific Celery worker is alive (1=up, 0=down)",
|
||||
labels=["worker"],
|
||||
)
|
||||
|
||||
try:
|
||||
inspector = self._celery_app.control.inspect(timeout=3.0)
|
||||
ping_result = inspector.ping()
|
||||
|
||||
responding: set[str] = set()
|
||||
if ping_result:
|
||||
active_workers.add_metric([], len(ping_result))
|
||||
for worker_name in ping_result:
|
||||
# Strip hostname suffix for cleaner labels
|
||||
short_name = worker_name.split("@")[0]
|
||||
responding.add(short_name)
|
||||
else:
|
||||
active_workers.add_metric([], 0)
|
||||
|
||||
# Register newly-seen workers and reset miss count for
|
||||
# workers that responded.
|
||||
for short_name in responding:
|
||||
self._known_workers[short_name] = 0
|
||||
|
||||
# Increment miss count for non-responding workers and evict
|
||||
# those that have been missing too long.
|
||||
stale = []
|
||||
for short_name in list(self._known_workers):
|
||||
if short_name not in responding:
|
||||
self._known_workers[short_name] += 1
|
||||
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
|
||||
stale.append(short_name)
|
||||
for short_name in stale:
|
||||
del self._known_workers[short_name]
|
||||
|
||||
for short_name in sorted(self._known_workers):
|
||||
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
return [active_workers, worker_up]
|
||||
113
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
113
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Setup function for indexing pipeline Prometheus collectors.
|
||||
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_worker_health_collector.set_celery_app(celery_app)
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
try:
|
||||
REGISTRY.register(collector)
|
||||
except ValueError:
|
||||
logger.debug("Collector already registered: %s", type(collector).__name__)
|
||||
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Per-connector Prometheus metrics for indexing tasks.
|
||||
|
||||
Enriches the two primary indexing tasks (docfetching_proxy_task and
|
||||
docprocessing_task) with connector-level labels: source, tenant_id,
|
||||
and cc_pair_id.
|
||||
|
||||
Note: connector_name is intentionally excluded from push-based per-task
|
||||
counters because it is a user-defined free-form string that can create
|
||||
unbounded cardinality. The pull-based collectors on the monitoring worker
|
||||
(see indexing_pipeline.py) include connector_name since they have bounded
|
||||
cardinality (one series per connector, not per task execution).
|
||||
|
||||
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
|
||||
Connectors never change source type, and names change rarely, so the
|
||||
cache is safe to hold for the worker's lifetime.
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.indexing_task_metrics import (
|
||||
on_indexing_task_prerun,
|
||||
on_indexing_task_postrun,
|
||||
)
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConnectorInfo:
|
||||
"""Cached connector metadata for metric labels."""
|
||||
|
||||
source: str
|
||||
name: str
|
||||
|
||||
|
||||
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
|
||||
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
|
||||
# deployments where different tenants can share the same cc_pair_id value.
|
||||
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
|
||||
|
||||
# Lock protecting _connector_cache — multiple thread-pool workers may
|
||||
# resolve connectors concurrently.
|
||||
_connector_cache_lock = threading.Lock()
|
||||
|
||||
# Only enrich these task types with per-connector labels
|
||||
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
}
|
||||
)
|
||||
|
||||
# connector_name is intentionally excluded — see module docstring.
|
||||
INDEXING_TASK_STARTED = Counter(
|
||||
"onyx_indexing_task_started_total",
|
||||
"Indexing tasks started per connector",
|
||||
["task_name", "source", "tenant_id", "cc_pair_id"],
|
||||
)
|
||||
|
||||
INDEXING_TASK_COMPLETED = Counter(
|
||||
"onyx_indexing_task_completed_total",
|
||||
"Indexing tasks completed per connector",
|
||||
[
|
||||
"task_name",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"cc_pair_id",
|
||||
"outcome",
|
||||
],
|
||||
)
|
||||
|
||||
INDEXING_TASK_DURATION = Histogram(
|
||||
"onyx_indexing_task_duration_seconds",
|
||||
"Indexing task duration by connector type",
|
||||
["task_name", "source", "tenant_id"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
# task_id → monotonic start time (for indexing tasks only)
|
||||
_indexing_start_times: dict[str, float] = {}
|
||||
|
||||
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_indexing_start_times_lock = threading.Lock()
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _indexing_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, start in _indexing_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
_indexing_start_times.pop(tid, None)
|
||||
|
||||
|
||||
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
|
||||
"""Resolve cc_pair_id to ConnectorInfo, using cache when possible.
|
||||
|
||||
On cache miss, does a single DB query with eager connector load.
|
||||
On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
|
||||
subsequent calls can retry the lookup once the DB is available.
|
||||
|
||||
Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
|
||||
cache key. The Celery tenant-aware middleware sets this contextvar before
|
||||
task execution, and it always matches kwargs["tenant_id"] (which is set
|
||||
at task dispatch time). They are guaranteed to agree for a given task
|
||||
execution context.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
|
||||
cache_key = (tenant_id, cc_pair_id)
|
||||
|
||||
with _connector_cache_lock:
|
||||
cached = _connector_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pair_from_id,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session,
|
||||
cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
# DB lookup succeeded but cc_pair doesn't exist — don't cache,
|
||||
# it may appear later (race with connector creation).
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
info = ConnectorInfo(
|
||||
source=cc_pair.connector.source.value,
|
||||
name=cc_pair.name,
|
||||
)
|
||||
with _connector_cache_lock:
|
||||
_connector_cache[cache_key] = info
|
||||
return info
|
||||
except Exception:
|
||||
logger.debug(
|
||||
f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
|
||||
def on_indexing_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
) -> None:
|
||||
"""Record per-connector metrics at task start.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
|
||||
all other tasks.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
|
||||
INDEXING_TASK_STARTED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_indexing_start_times[task_id] = time.monotonic()
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_indexing_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record per-connector completion metrics.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
|
||||
INDEXING_TASK_COMPLETED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
outcome=outcome,
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
start = _indexing_start_times.pop(task_id, None)
|
||||
if start is not None:
|
||||
INDEXING_TASK_DURATION.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
).observe(time.monotonic() - start)
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task postrun metrics", exc_info=True)
|
||||
89
backend/onyx/server/metrics/metrics_server.py
Normal file
89
backend/onyx/server/metrics/metrics_server.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Standalone Prometheus metrics HTTP server for non-API processes.
|
||||
|
||||
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
|
||||
Celery workers and other background processes use this module to expose their
|
||||
own /metrics endpoint on a configurable port.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
start_metrics_server("monitoring") # reads port from env or uses default
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default ports for worker types that serve custom Prometheus metrics.
|
||||
# Only add entries here when a worker actually registers collectors.
|
||||
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
|
||||
# env var can override.
|
||||
_DEFAULT_PORTS: dict[str, int] = {
|
||||
"monitoring": 9096,
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
_server_lock = threading.Lock()
|
||||
|
||||
|
||||
def start_metrics_server(worker_type: str) -> int | None:
|
||||
"""Start a Prometheus metrics HTTP server in a background thread.
|
||||
|
||||
Returns the port if started, None if disabled or already started.
|
||||
|
||||
Port resolution order:
|
||||
1. PROMETHEUS_METRICS_PORT env var (explicit override)
|
||||
2. Default port for the worker type
|
||||
3. If worker type is unknown and no env var, skip
|
||||
|
||||
Set PROMETHEUS_METRICS_ENABLED=false to disable.
|
||||
"""
|
||||
global _server_started
|
||||
|
||||
with _server_lock:
|
||||
if _server_started:
|
||||
logger.debug(f"Metrics server already started for {worker_type}")
|
||||
return None
|
||||
|
||||
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
|
||||
if enabled in ("false", "0", "no"):
|
||||
logger.info(f"Prometheus metrics server disabled for {worker_type}")
|
||||
return None
|
||||
|
||||
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
|
||||
if port_str:
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
|
||||
"must be a numeric port. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
elif worker_type in _DEFAULT_PORTS:
|
||||
port = _DEFAULT_PORTS[worker_type]
|
||||
else:
|
||||
logger.info(
|
||||
f"No default metrics port for worker type '{worker_type}' "
|
||||
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
start_http_server(port)
|
||||
_server_started = True
|
||||
logger.info(
|
||||
f"Prometheus metrics server started on :{port} for {worker_type}"
|
||||
)
|
||||
return port
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
|
||||
)
|
||||
return None
|
||||
@@ -17,6 +17,7 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
@@ -80,6 +81,7 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -104,5 +104,7 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -736,7 +736,7 @@ if __name__ == "__main__":
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
persona = get_default_behavior_persona(db_session)
|
||||
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
|
||||
if persona is None:
|
||||
raise ValueError("No default persona found")
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -124,7 +125,12 @@ def construct_tools(
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs.
|
||||
|
||||
Will simply skip tools that are not allowed/available."""
|
||||
Will simply skip tools that are not allowed/available.
|
||||
|
||||
Callers must supply a persona with ``tools``, ``document_sets``,
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
@@ -143,6 +149,28 @@ def construct_tools(
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
|
||||
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
|
||||
persona_search_info = PersonaSearchInfo(
|
||||
document_set_names=[ds.name for ds in persona.document_sets],
|
||||
search_start_date=persona.search_start_date,
|
||||
attached_document_ids=[doc.id for doc in persona.attached_documents],
|
||||
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
|
||||
)
|
||||
return SearchTool(
|
||||
tool_id=tool_id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona_search_info=persona_search_info,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=config.user_selected_filters,
|
||||
project_id_filter=config.project_id_filter,
|
||||
persona_id_filter=config.persona_id_filter,
|
||||
bypass_acl=config.bypass_acl,
|
||||
slack_context=config.slack_context,
|
||||
enable_slack_search=config.enable_slack_search,
|
||||
)
|
||||
|
||||
added_search_tool = False
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
@@ -176,22 +204,9 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
tool_dict[db_tool_model.id] = [
|
||||
_build_search_tool(db_tool_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -421,26 +436,12 @@ def construct_tools(
|
||||
# Get the database tool model for SearchTool
|
||||
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
|
||||
|
||||
# Use the passed-in config if available, otherwise create a new one
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [search_tool]
|
||||
tool_dict[search_tool_db_model.id] = [
|
||||
_build_search_tool(search_tool_db_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Always inject MemoryTool when the user has the memory tool enabled,
|
||||
# bypassing persona tool associations and allowed_tool_ids filtering
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
@@ -65,7 +66,6 @@ from onyx.db.federated import (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names,
|
||||
)
|
||||
from onyx.db.federated import list_federated_connector_oauth_tokens
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
emitter: Emitter,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for filter settings
|
||||
persona: Persona,
|
||||
# Pre-extracted persona search configuration
|
||||
persona_search_info: PersonaSearchInfo,
|
||||
llm: LLM,
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
super().__init__(emitter=emitter)
|
||||
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
self.persona_search_info = persona_search_info
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Case 1: Slack bot context — requires a Slack federated connector
|
||||
# linked via the persona's document sets
|
||||
if self.slack_context:
|
||||
document_set_names = [ds.name for ds in self.persona.document_sets]
|
||||
document_set_names = self.persona_search_info.document_set_names
|
||||
if not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping Slack federated search: no document sets on persona"
|
||||
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
persona_search_info=self.persona_search_info,
|
||||
acl_filters=acl_filters,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=federated_retrieval_infos,
|
||||
@@ -587,15 +587,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
and self.user_selected_filters.source_type
|
||||
else None
|
||||
)
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
document_set_names=self.persona_search_info.document_set_names,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -153,13 +153,15 @@ class TestAdapterWritesBothMetadataFields:
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -188,13 +190,15 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@@ -225,13 +229,14 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@@ -256,13 +261,14 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -294,11 +300,12 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import UserProject
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
|
||||
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
|
||||
user = create_test_user(db_session, "eager-load")
|
||||
persona = Persona(name="eager-load-test", description="test")
|
||||
project = UserProject(name="eager-load-project", user_id=user.id)
|
||||
db_session.add_all([persona, project])
|
||||
db_session.flush()
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="test",
|
||||
user_id=None,
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
|
||||
loaded = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
try:
|
||||
tmp = inspect(loaded)
|
||||
assert tmp is not None
|
||||
unloaded = tmp.unloaded
|
||||
assert "persona" not in unloaded
|
||||
assert "project" not in unloaded
|
||||
|
||||
tmp = inspect(loaded.persona)
|
||||
assert tmp is not None
|
||||
persona_unloaded = tmp.unloaded
|
||||
assert "tools" not in persona_unloaded
|
||||
assert "user_files" not in persona_unloaded
|
||||
assert "document_sets" not in persona_unloaded
|
||||
assert "attached_documents" not in persona_unloaded
|
||||
assert "hierarchy_nodes" not in persona_unloaded
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -6,7 +6,6 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -22,7 +21,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -203,25 +201,3 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,7 +5,6 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -167,29 +166,3 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
@@ -139,12 +139,12 @@ def use_mock_search_pipeline(
|
||||
chunk_search_request: ChunkSearchRequest,
|
||||
document_index: DocumentIndex, # noqa: ARG001
|
||||
user: User | None, # noqa: ARG001
|
||||
persona: Persona | None, # noqa: ARG001
|
||||
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
|
||||
db_session: Session | None = None, # noqa: ARG001
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
project_id_filter: int | None = None, # noqa: ARG001
|
||||
persona_id_filter: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
|
||||
53
backend/tests/unit/ee/onyx/db/test_user_group_rename.py
Normal file
53
backend/tests/unit/ee/onyx/db/test_user_group_rename.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Tests for user group rename DB operation."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
class TestRenameUserGroup:
|
||||
"""Tests for rename_user_group function."""
|
||||
|
||||
@patch("ee.onyx.db.user_group.DISABLE_VECTOR_DB", False)
|
||||
@patch(
|
||||
"ee.onyx.db.user_group._mark_user_group__cc_pair_relationships_outdated__no_commit"
|
||||
)
|
||||
def test_rename_succeeds_and_triggers_sync(
|
||||
self, mock_mark_outdated: MagicMock
|
||||
) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_group = MagicMock(spec=UserGroup)
|
||||
mock_group.name = "Old Name"
|
||||
mock_group.is_up_to_date = True
|
||||
mock_session.scalar.return_value = mock_group
|
||||
|
||||
result = rename_user_group(mock_session, user_group_id=1, new_name="New Name")
|
||||
|
||||
assert result.name == "New Name"
|
||||
assert result.is_up_to_date is False
|
||||
mock_mark_outdated.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_rename_group_not_found(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
rename_user_group(mock_session, user_group_id=999, new_name="New Name")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_rename_group_syncing_raises(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_group = MagicMock(spec=UserGroup)
|
||||
mock_group.is_up_to_date = False
|
||||
mock_session.scalar.return_value = mock_group
|
||||
|
||||
with pytest.raises(ValueError, match="currently syncing"):
|
||||
rename_user_group(mock_session, user_group_id=1, new_name="New Name")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Unit tests for the check_available_tenants task.
|
||||
|
||||
Tests verify:
|
||||
- Provisioning loop calls pre_provision_tenant the correct number of times
|
||||
- Batch size is capped at _MAX_TENANTS_PER_RUN
|
||||
- A failure in one provisioning call does not stop subsequent calls
|
||||
- No provisioning happens when pool is already full
|
||||
- TARGET_AVAILABLE_TENANTS is respected
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
|
||||
_MAX_TENANTS_PER_RUN,
|
||||
)
|
||||
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
|
||||
check_available_tenants,
|
||||
)
|
||||
|
||||
# Access the underlying function directly, bypassing Celery's task wrapper
|
||||
# which injects `self` as the first argument when bind=True.
|
||||
_check_available_tenants = check_available_tenants.run
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _enable_multi_tenant(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_redis(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.acquire.return_value = True
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.lock.return_value = mock_lock
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_redis_client",
|
||||
lambda tenant_id: mock_client, # noqa: ARG005
|
||||
)
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_pre_provision(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=True)
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.pre_provision_tenant",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def _mock_available_count(monkeypatch: pytest.MonkeyPatch, count: int) -> None:
|
||||
"""Set up the DB session mock to return a specific available tenant count."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.query.return_value.count.return_value = count
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_session_with_shared_schema",
|
||||
lambda: mock_session,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_multi_tenant", "mock_redis")
|
||||
class TestCheckAvailableTenants:
|
||||
def test_provisions_all_needed_tenants(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool has 2 and target is 5, should provision 3."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 2)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 3
|
||||
|
||||
def test_batch_capped_at_max_per_run(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool needs more than _MAX_TENANTS_PER_RUN, cap the batch."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
20,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 0)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == _MAX_TENANTS_PER_RUN
|
||||
|
||||
def test_no_provisioning_when_pool_full(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool already meets target, should not provision anything."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 5)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_no_provisioning_when_pool_exceeds_target(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool exceeds target, should not provision anything."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 8)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_failure_does_not_stop_remaining(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""If one provisioning fails, the rest should still be attempted."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 0)
|
||||
|
||||
# Fail on calls 2 and 4 (1-indexed)
|
||||
call_count = 0
|
||||
|
||||
def side_effect() -> bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count in (2, 4):
|
||||
raise RuntimeError("provisioning failed")
|
||||
return True
|
||||
|
||||
mock_pre_provision.side_effect = side_effect
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
# All 5 should be attempted despite 2 failures
|
||||
assert mock_pre_provision.call_count == 5
|
||||
|
||||
def test_skips_when_not_multi_tenant(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""Should not provision when multi-tenancy is disabled."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
|
||||
False,
|
||||
)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_skips_when_lock_not_acquired(
|
||||
self,
|
||||
mock_redis: MagicMock,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""Should skip when another instance holds the lock."""
|
||||
mock_redis.lock.return_value.acquire.return_value = False
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_lock_release_failure_does_not_raise(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_redis: MagicMock,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""LockNotOwnedError on release should be caught, not propagated."""
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 4)
|
||||
|
||||
mock_redis.lock.return_value.release.side_effect = LockNotOwnedError("expired")
|
||||
|
||||
# Should not raise
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 1
|
||||
@@ -1,4 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from onyx.chat.process_message import _resolve_query_processing_hook_result
|
||||
from onyx.chat.process_message import remove_answer_citations
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
|
||||
|
||||
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
|
||||
@@ -32,3 +40,81 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
|
||||
remove_answer_citations(answer)
|
||||
== "See [reference](https://example.com/Function_(mathematics)) for context."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query Processing hook response handling (_resolve_query_processing_hook_result)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_hook_skipped_leaves_message_text_unchanged() -> None:
|
||||
result = _resolve_query_processing_hook_result(HookSkipped(), "original query")
|
||||
assert result == "original query"
|
||||
|
||||
|
||||
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
|
||||
result = _resolve_query_processing_hook_result(HookSoftFailed(), "original query")
|
||||
assert result == "original query"
|
||||
|
||||
|
||||
def test_null_query_raises_query_rejected() -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=None), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_empty_string_query_raises_query_rejected() -> None:
|
||||
"""Empty string is falsy — must be treated as rejection, same as None."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=""), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_whitespace_only_query_raises_query_rejected() -> None:
|
||||
"""Whitespace-only string is truthy but meaningless — must be treated as rejection."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=" "), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_absent_query_field_raises_query_rejected() -> None:
|
||||
"""query defaults to None when not provided."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_rejection_message_surfaced_in_error_when_provided() -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(
|
||||
query=None, rejection_message="Queries about X are not allowed."
|
||||
),
|
||||
"original query",
|
||||
)
|
||||
assert "Queries about X are not allowed." in str(exc_info.value)
|
||||
|
||||
|
||||
def test_fallback_rejection_message_when_none() -> None:
|
||||
"""No rejection_message → generic fallback used in OnyxError detail."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=None, rejection_message=None),
|
||||
"original query",
|
||||
)
|
||||
assert "No rejection reason was provided." in str(exc_info.value)
|
||||
|
||||
|
||||
def test_nonempty_query_rewrites_message_text() -> None:
|
||||
result = _resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query="rewritten query"), "original query"
|
||||
)
|
||||
assert result == "rewritten query"
|
||||
|
||||
@@ -60,4 +60,4 @@ def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
with pytest.raises(HTTPError):
|
||||
handled_call()
|
||||
|
||||
assert mock_confluence_call.call_count == 1
|
||||
assert mock_confluence_call.call_count == 5
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
"""Unit tests for Notion connector data source API migration.
|
||||
|
||||
Tests the new data source discovery + querying flow and the
|
||||
data_source_id -> database_id parent resolution.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.notion.connector import NotionDataSource
|
||||
from onyx.connectors.notion.connector import NotionPage
|
||||
|
||||
|
||||
def _make_connector() -> NotionConnector:
|
||||
connector = NotionConnector()
|
||||
connector.load_credentials({"notion_integration_token": "fake-token"})
|
||||
return connector
|
||||
|
||||
|
||||
def _mock_response(json_data: dict, status_code: int = 200) -> MagicMock:
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = json_data
|
||||
resp.status_code = status_code
|
||||
if status_code >= 400:
|
||||
resp.raise_for_status.side_effect = HTTPError(
|
||||
f"HTTP {status_code}", response=resp
|
||||
)
|
||||
else:
|
||||
resp.raise_for_status.return_value = None
|
||||
return resp
|
||||
|
||||
|
||||
class TestFetchDataSourcesForDatabase:
|
||||
def test_multi_source_database(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response(
|
||||
{
|
||||
"object": "database",
|
||||
"id": "db-1",
|
||||
"data_sources": [
|
||||
{"id": "ds-1", "name": "Source A"},
|
||||
{"id": "ds-2", "name": "Source B"},
|
||||
],
|
||||
}
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.get", return_value=resp
|
||||
):
|
||||
result = connector._fetch_data_sources_for_database("db-1")
|
||||
|
||||
assert result == [
|
||||
NotionDataSource(id="ds-1", name="Source A"),
|
||||
NotionDataSource(id="ds-2", name="Source B"),
|
||||
]
|
||||
|
||||
def test_single_source_database(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response(
|
||||
{
|
||||
"object": "database",
|
||||
"id": "db-1",
|
||||
"data_sources": [{"id": "ds-1", "name": "Only Source"}],
|
||||
}
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.get", return_value=resp
|
||||
):
|
||||
result = connector._fetch_data_sources_for_database("db-1")
|
||||
|
||||
assert result == [NotionDataSource(id="ds-1", name="Only Source")]
|
||||
|
||||
def test_404_returns_empty(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response({"object": "error"}, status_code=404)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.get", return_value=resp
|
||||
):
|
||||
result = connector._fetch_data_sources_for_database("db-missing")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestFetchDataSource:
|
||||
def test_query_returns_pages(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response(
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"object": "page",
|
||||
"id": "page-1",
|
||||
"properties": {"Name": {"type": "title", "title": []}},
|
||||
}
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.post", return_value=resp
|
||||
):
|
||||
result = connector._fetch_data_source("ds-1")
|
||||
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["id"] == "page-1"
|
||||
assert result["next_cursor"] is None
|
||||
|
||||
def test_404_returns_empty_results(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response({"object": "error"}, status_code=404)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.post", return_value=resp
|
||||
):
|
||||
result = connector._fetch_data_source("ds-missing")
|
||||
|
||||
assert result == {"results": [], "next_cursor": None}
|
||||
|
||||
|
||||
class TestGetParentRawId:
|
||||
def test_database_id_parent(self) -> None:
|
||||
connector = _make_connector()
|
||||
parent = {"type": "database_id", "database_id": "db-1"}
|
||||
assert connector._get_parent_raw_id(parent) == "db-1"
|
||||
|
||||
def test_data_source_id_with_mapping(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector._data_source_to_database_map["ds-1"] = "db-1"
|
||||
parent = {"type": "data_source_id", "data_source_id": "ds-1"}
|
||||
assert connector._get_parent_raw_id(parent) == "db-1"
|
||||
|
||||
def test_data_source_id_without_mapping_falls_back(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
parent = {"type": "data_source_id", "data_source_id": "ds-unknown"}
|
||||
assert connector._get_parent_raw_id(parent) == "ws-1"
|
||||
|
||||
def test_workspace_parent(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
parent = {"type": "workspace"}
|
||||
assert connector._get_parent_raw_id(parent) == "ws-1"
|
||||
|
||||
def test_page_id_parent(self) -> None:
|
||||
connector = _make_connector()
|
||||
parent = {"type": "page_id", "page_id": "page-1"}
|
||||
assert connector._get_parent_raw_id(parent) == "page-1"
|
||||
|
||||
def test_block_id_parent_with_mapping(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
connector._child_page_parent_map["inline-page-1"] = "containing-page-1"
|
||||
parent = {"type": "block_id"}
|
||||
assert (
|
||||
connector._get_parent_raw_id(parent, page_id="inline-page-1")
|
||||
== "containing-page-1"
|
||||
)
|
||||
|
||||
def test_block_id_parent_without_mapping_falls_back(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
parent = {"type": "block_id"}
|
||||
assert connector._get_parent_raw_id(parent, page_id="unknown-page") == "ws-1"
|
||||
|
||||
def test_none_parent_defaults_to_workspace(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
assert connector._get_parent_raw_id(None) == "ws-1"
|
||||
|
||||
|
||||
class TestReadPagesFromDatabaseMultiSource:
|
||||
def test_queries_all_data_sources(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_sources_for_database",
|
||||
return_value=[
|
||||
NotionDataSource(id="ds-1", name="Source A"),
|
||||
NotionDataSource(id="ds-2", name="Source B"),
|
||||
],
|
||||
),
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_source",
|
||||
return_value={"results": [], "next_cursor": None},
|
||||
) as mock_fetch_ds,
|
||||
):
|
||||
result = connector._read_pages_from_database("db-1")
|
||||
|
||||
assert mock_fetch_ds.call_count == 2
|
||||
mock_fetch_ds.assert_any_call("ds-1", None)
|
||||
mock_fetch_ds.assert_any_call("ds-2", None)
|
||||
|
||||
assert connector._data_source_to_database_map["ds-1"] == "db-1"
|
||||
assert connector._data_source_to_database_map["ds-2"] == "db-1"
|
||||
|
||||
assert result.blocks == []
|
||||
assert result.child_page_ids == []
|
||||
assert len(result.hierarchy_nodes) == 1
|
||||
assert result.hierarchy_nodes[0].raw_node_id == "db-1"
|
||||
|
||||
def test_collects_pages_from_all_sources(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
connector.recursive_index_enabled = True
|
||||
|
||||
ds1_results = {
|
||||
"results": [{"object": "page", "id": "page-from-ds1", "properties": {}}],
|
||||
"next_cursor": None,
|
||||
}
|
||||
ds2_results = {
|
||||
"results": [{"object": "page", "id": "page-from-ds2", "properties": {}}],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_sources_for_database",
|
||||
return_value=[
|
||||
NotionDataSource(id="ds-1", name="Source A"),
|
||||
NotionDataSource(id="ds-2", name="Source B"),
|
||||
],
|
||||
),
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_source",
|
||||
side_effect=[ds1_results, ds2_results],
|
||||
),
|
||||
):
|
||||
result = connector._read_pages_from_database("db-1")
|
||||
|
||||
assert "page-from-ds1" in result.child_page_ids
|
||||
assert "page-from-ds2" in result.child_page_ids
|
||||
|
||||
def test_pagination_across_pages(self) -> None:
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
connector.recursive_index_enabled = True
|
||||
|
||||
page1 = {
|
||||
"results": [{"object": "page", "id": "page-1", "properties": {}}],
|
||||
"next_cursor": "cursor-abc",
|
||||
}
|
||||
page2 = {
|
||||
"results": [{"object": "page", "id": "page-2", "properties": {}}],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_sources_for_database",
|
||||
return_value=[NotionDataSource(id="ds-1", name="Source A")],
|
||||
),
|
||||
patch.object(
|
||||
connector,
|
||||
"_fetch_data_source",
|
||||
side_effect=[page1, page2],
|
||||
) as mock_fetch_ds,
|
||||
):
|
||||
result = connector._read_pages_from_database("db-1")
|
||||
|
||||
assert mock_fetch_ds.call_count == 2
|
||||
mock_fetch_ds.assert_any_call("ds-1", None)
|
||||
mock_fetch_ds.assert_any_call("ds-1", "cursor-abc")
|
||||
assert "page-1" in result.child_page_ids
|
||||
assert "page-2" in result.child_page_ids
|
||||
|
||||
|
||||
class TestInTrashField:
|
||||
def test_notion_page_accepts_in_trash(self) -> None:
|
||||
page = NotionPage(
|
||||
id="page-1",
|
||||
created_time="2026-01-01T00:00:00.000Z",
|
||||
last_edited_time="2026-01-01T00:00:00.000Z",
|
||||
in_trash=False,
|
||||
properties={},
|
||||
url="https://notion.so/page-1",
|
||||
)
|
||||
assert page.in_trash is False
|
||||
|
||||
def test_notion_page_in_trash_true(self) -> None:
|
||||
page = NotionPage(
|
||||
id="page-1",
|
||||
created_time="2026-01-01T00:00:00.000Z",
|
||||
last_edited_time="2026-01-01T00:00:00.000Z",
|
||||
in_trash=True,
|
||||
properties={},
|
||||
url="https://notion.so/page-1",
|
||||
)
|
||||
assert page.in_trash is True
|
||||
|
||||
|
||||
class TestFetchDatabaseAsPage:
|
||||
def test_handles_missing_properties(self) -> None:
|
||||
connector = _make_connector()
|
||||
resp = _mock_response(
|
||||
{
|
||||
"object": "database",
|
||||
"id": "db-1",
|
||||
"created_time": "2026-01-01T00:00:00.000Z",
|
||||
"last_edited_time": "2026-01-01T00:00:00.000Z",
|
||||
"in_trash": False,
|
||||
"url": "https://notion.so/db-1",
|
||||
"title": [{"text": {"content": "My DB"}, "plain_text": "My DB"}],
|
||||
"data_sources": [{"id": "ds-1", "name": "Source"}],
|
||||
}
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.notion.connector.rl_requests.get", return_value=resp
|
||||
):
|
||||
page = connector._fetch_database_as_page("db-1")
|
||||
|
||||
assert page.id == "db-1"
|
||||
assert page.database_name == "My DB"
|
||||
assert page.properties == {}
|
||||
@@ -1,226 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int,
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
sections=[TextSection(text="test", link="http://test.com")],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="test_doc",
|
||||
metadata={},
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links={0: "http://test.com"},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings={"full_embedding": [0.1] * 10, "mini_chunk_embeddings": []},
|
||||
title_embedding=[0.1] * 10,
|
||||
tenant_id="test_tenant",
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
ancestor_hierarchy_node_ids=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_index() -> OpenSearchDocumentIndex:
|
||||
"""Creates an OpenSearchDocumentIndex with a mocked client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents = MagicMock()
|
||||
|
||||
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
|
||||
|
||||
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
|
||||
index._index_name = "test_index"
|
||||
index._client = mock_client
|
||||
index._tenant_state = tenant_state
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=0,
|
||||
new_chunk_cnt=chunk_count,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_under_batch_limit_flushes_once() -> None:
|
||||
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 50
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
batch_arg = index._client.bulk_index_documents.call_args_list[0]
|
||||
assert len(batch_arg.kwargs["documents"]) == num_chunks
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
|
||||
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 100, 50]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_exactly_at_batch_limit() -> None:
|
||||
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
|
||||
(the flush happens on the next chunk, not at the boundary)."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 100
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
|
||||
# so final flush handles all 100
|
||||
# Actually: the elif fires when len(current_chunks) >= 100, which happens
|
||||
# when current_chunks has 100 items and the 101st chunk arrives.
|
||||
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
|
||||
# No 101st chunk arrives, so the final flush handles all 100.
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_one_over_batch_limit() -> None:
|
||||
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
|
||||
the 101st is flushed at the end."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 101
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 2
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 1]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
|
||||
"""Multiple documents each under the batch limit should flush once per document."""
|
||||
index = _make_index()
|
||||
chunks = []
|
||||
for doc_idx in range(3):
|
||||
doc_id = f"doc_{doc_idx}"
|
||||
for chunk_idx in range(50):
|
||||
chunks.append(_make_chunk(doc_id, chunk_idx))
|
||||
|
||||
metadata = IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
|
||||
for i in range(3)
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 3 documents = 3 flushes (one per doc boundary + final)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_delete_called_once_per_document() -> None:
|
||||
"""Even with multiple flushes for a single document, delete should only be
|
||||
called once per document."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0) as mock_delete:
|
||||
index.index(chunks, metadata)
|
||||
|
||||
mock_delete.assert_called_once_with(doc_id, None)
|
||||
@@ -1,152 +0,0 @@
|
||||
"""Unit tests for VespaDocumentIndex.index().
|
||||
|
||||
These tests mock all external I/O (HTTP calls, thread pools) and verify
|
||||
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
|
||||
construction.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stub_enrich(
|
||||
doc_id: str,
|
||||
old_chunk_cnt: int,
|
||||
) -> EnrichedDocumentIndexingInfo:
|
||||
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
|
||||
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
|
||||
return EnrichedDocumentIndexingInfo(
|
||||
doc_id=doc_id,
|
||||
chunk_start_index=0,
|
||||
old_version=False,
|
||||
chunk_end_index=old_chunk_cnt,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
|
||||
3,
|
||||
)
|
||||
def test_index_respects_batch_size(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
|
||||
multiple times with correctly sized batches."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
|
||||
assert mock_batch_index.call_count == 3
|
||||
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
|
||||
assert batch_sizes == [3, 3, 1]
|
||||
|
||||
# Verify all chunks are accounted for and in order
|
||||
all_indexed = [
|
||||
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
|
||||
]
|
||||
assert len(all_indexed) == 7
|
||||
assert [c.chunk_id for c in all_indexed] == list(range(7))
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -15,13 +16,15 @@ from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
# A valid QueryProcessingResponse payload — used by success-path tests.
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
@@ -33,6 +36,7 @@ def _make_hook(
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
@@ -42,6 +46,7 @@ def _make_hook(
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
hook.hook_point = hook_point
|
||||
return hook
|
||||
|
||||
|
||||
@@ -140,6 +145,7 @@ def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
@@ -152,7 +158,9 @@ def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
def test_success_returns_validated_model_and_sets_reachable(
|
||||
db_session: MagicMock,
|
||||
) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
@@ -171,9 +179,11 @@ def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> No
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
@@ -200,9 +210,11 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
@@ -230,6 +242,7 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
@@ -265,6 +278,7 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
@@ -388,6 +402,7 @@ def test_http_failure_paths(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
@@ -395,6 +410,7 @@ def test_http_failure_paths(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
@@ -442,6 +458,7 @@ def test_authorization_header(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
@@ -457,16 +474,16 @@ def test_authorization_header(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expected_result",
|
||||
"http_exception,expect_onyx_error",
|
||||
[
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
pytest.param(None, False, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expected_result: Any,
|
||||
expect_onyx_error: bool,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
@@ -489,12 +506,13 @@ def test_persist_session_failure_is_swallowed(
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expected_result is OnyxError:
|
||||
if expect_onyx_error:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
@@ -502,8 +520,131 @@ def test_persist_session_failure_is_swallowed(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert result == expected_result
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response model validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StrictResponse(BaseModel):
|
||||
"""Strict model used to reliably trigger a ValidationError in tests."""
|
||||
|
||||
required_field: str # no default → missing key raises ValidationError
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_strategy,expected_type",
|
||||
[
|
||||
pytest.param(
|
||||
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
|
||||
),
|
||||
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
|
||||
],
|
||||
)
|
||||
def test_response_validation_failure_respects_fail_strategy(
|
||||
db_session: MagicMock,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
) -> None:
|
||||
"""A response that fails response_model validation is treated like any other
|
||||
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
# Response payload is missing required_field → ValidationError
|
||||
_setup_client(mock_client_cls, response=_make_response(json_return={}))
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=_StrictResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=_StrictResponse,
|
||||
)
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
|
||||
# is_reachable must not be updated — server responded correctly
|
||||
mock_update.assert_not_called()
|
||||
# failure must be logged
|
||||
mock_log.assert_called_once()
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "validation" in (log_kwargs["error_message"] or "").lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Outer soft-fail guard in execute_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_strategy,expected_type",
|
||||
[
|
||||
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
|
||||
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
|
||||
],
|
||||
)
|
||||
def test_unexpected_exception_in_inner_respects_fail_strategy(
|
||||
db_session: MagicMock,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
) -> None:
|
||||
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
|
||||
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
|
||||
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor._execute_hook_inner",
|
||||
side_effect=ValueError("unexpected bug"),
|
||||
),
|
||||
):
|
||||
if expected_type is HookSoftFailed:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
else:
|
||||
with pytest.raises(ValueError, match="unexpected bug"):
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
@@ -535,6 +676,7 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
|
||||
@@ -116,7 +116,7 @@ def _run_adapter_build(
|
||||
project_ids_map: dict[str, list[int]],
|
||||
persona_ids_map: dict[str, list[int]],
|
||||
) -> list[DocMetadataAwareIndexChunk]:
|
||||
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
|
||||
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
|
||||
with all external dependencies mocked."""
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import (
|
||||
UserFileIndexingAdapter,
|
||||
@@ -155,12 +155,14 @@ def _run_adapter_build(
|
||||
side_effect=Exception("no LLM in tests"),
|
||||
),
|
||||
):
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id="test_tenant",
|
||||
chunks=[chunk],
|
||||
context=context,
|
||||
)
|
||||
return [enricher.enrich_chunk(chunk, 1.0)]
|
||||
|
||||
return result.chunks
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
|
||||
153
backend/tests/unit/server/metrics/test_celery_task_metrics.py
Normal file
153
backend/tests/unit/server/metrics/test_celery_task_metrics.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for generic Celery task lifecycle Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.celery_task_metrics import _task_start_times
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
|
||||
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_metrics() -> Iterator[None]:
|
||||
"""Clear metric state between tests."""
|
||||
_task_start_times.clear()
|
||||
yield
|
||||
_task_start_times.clear()
|
||||
|
||||
|
||||
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
task.request = MagicMock()
|
||||
task.request.delivery_info = {"routing_key": queue}
|
||||
return task
|
||||
|
||||
|
||||
class TestCeleryTaskPrerun:
|
||||
def test_increments_started_and_active(self) -> None:
|
||||
task = _make_task()
|
||||
before_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
before_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
after_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
after_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
assert after_started == before_started + 1
|
||||
assert after_active == before_active + 1
|
||||
|
||||
def test_records_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_prerun("task-1", None)
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_id_is_none(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun(None, task)
|
||||
# Should not crash
|
||||
|
||||
def test_handles_missing_delivery_info(self) -> None:
|
||||
task = _make_task()
|
||||
task.request.delivery_info = None
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
|
||||
class TestCeleryTaskPostrun:
|
||||
def test_increments_completed_success(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_increments_completed_failure(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "FAILURE")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_decrements_active(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
active_before = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
active_after = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
assert active_after == active_before - 1
|
||||
|
||||
def test_observes_duration(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
# Duration should have increased (at least slightly)
|
||||
assert after_count > before_count
|
||||
|
||||
def test_cleans_up_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_postrun("task-1", None, "SUCCESS")
|
||||
|
||||
def test_handles_missing_start_time(self) -> None:
|
||||
"""Postrun without prerun should not crash."""
|
||||
task = _make_task()
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
# Should not raise
|
||||
@@ -0,0 +1,359 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value={"task-1", "task-2"},
|
||||
),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 3
|
||||
depth_family = families[0]
|
||||
unacked_family = families[1]
|
||||
age_family = families[2]
|
||||
|
||||
assert depth_family.name == "onyx_queue_depth"
|
||||
assert len(depth_family.samples) > 0
|
||||
for sample in depth_family.samples:
|
||||
assert sample.value == 5
|
||||
|
||||
assert unacked_family.name == "onyx_queue_unacked"
|
||||
unacked_labels = {s.labels["queue"] for s in unacked_family.samples}
|
||||
assert "docfetching" in unacked_labels
|
||||
assert "docprocessing" in unacked_labels
|
||||
|
||||
assert age_family.name == "onyx_queue_oldest_task_age_seconds"
|
||||
for sample in unacked_family.samples:
|
||||
assert sample.value == 2
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("connection lost"),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
# Returns stale cache (empty on first call)
|
||||
assert families == []
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
first = collector.collect()
|
||||
|
||||
# Second call within TTL should return cached result without calling Redis
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("should not be called"),
|
||||
):
|
||||
second = collector.collect()
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=10,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
good_result = collector.collect()
|
||||
|
||||
assert len(good_result) == 3
|
||||
assert good_result[0].samples[0].value == 10
|
||||
|
||||
# Second call fails — should return stale cache, not empty
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("Redis down"),
|
||||
):
|
||||
stale_result = collector.collect()
|
||||
|
||||
assert stale_result is good_result
|
||||
|
||||
|
||||
class TestIndexAttemptCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_index_attempts(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
mock_row = (
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
MagicMock(value="web"),
|
||||
81,
|
||||
"Table Tennis Blade Guide",
|
||||
2,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
|
||||
mock_row
|
||||
]
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 1
|
||||
assert families[0].name == "onyx_index_attempts_active"
|
||||
assert len(families[0].samples) == 1
|
||||
sample = families[0].samples[0]
|
||||
assert sample.labels == {
|
||||
"status": "in_progress",
|
||||
"source": "web",
|
||||
"tenant_id": "public",
|
||||
"connector_name": "Table Tennis Blade Guide",
|
||||
"cc_pair_id": "81",
|
||||
}
|
||||
assert sample.value == 2
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
# No stale cache, so returns empty
|
||||
assert families == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_skips_none_tenant_ids(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = [None]
|
||||
families = collector.collect()
|
||||
assert len(families) == 1 # Returns the gauge family, just with no samples
|
||||
assert len(families[0].samples) == 0
|
||||
|
||||
|
||||
class TestConnectorHealthCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_connector_health(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
last_success = now - timedelta(hours=2)
|
||||
|
||||
mock_status = MagicMock(value="ACTIVE")
|
||||
mock_source = MagicMock(value="google_drive")
|
||||
# Row: (id, status, in_error, last_success, name, source)
|
||||
mock_row = (
|
||||
42,
|
||||
mock_status,
|
||||
True, # in_repeated_error_state
|
||||
last_success,
|
||||
"My GDrive Connector",
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
# Mock the index attempt queries (error counts + docs counts)
|
||||
mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
|
||||
[]
|
||||
)
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 6
|
||||
names = {f.name for f in families}
|
||||
assert names == {
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"onyx_connector_in_error_state",
|
||||
"onyx_connectors_by_status",
|
||||
"onyx_connectors_in_error_total",
|
||||
"onyx_connector_docs_indexed",
|
||||
"onyx_connector_error_count",
|
||||
}
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 1
|
||||
assert staleness.samples[0].value == pytest.approx(7200, abs=5)
|
||||
|
||||
error_state = next(
|
||||
f for f in families if f.name == "onyx_connector_in_error_state"
|
||||
)
|
||||
assert error_state.samples[0].value == 1.0
|
||||
|
||||
by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
|
||||
assert by_status.samples[0].labels == {
|
||||
"tenant_id": "public",
|
||||
"status": "ACTIVE",
|
||||
}
|
||||
assert by_status.samples[0].value == 1
|
||||
|
||||
error_total = next(
|
||||
f for f in families if f.name == "onyx_connectors_in_error_total"
|
||||
)
|
||||
assert error_total.samples[0].value == 1
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_skips_staleness_when_no_last_success(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_status = MagicMock(value="INITIAL_INDEXING")
|
||||
mock_source = MagicMock(value="slack")
|
||||
mock_row = (
|
||||
10,
|
||||
mock_status,
|
||||
False,
|
||||
None, # no last_successful_index_time
|
||||
0,
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 0
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
assert families == []
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
335
backend/tests/unit/server/metrics/test_indexing_task_metrics.py
Normal file
335
backend/tests/unit/server/metrics/test_indexing_task_metrics.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for per-connector indexing task Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_task_metrics import _connector_cache
|
||||
from onyx.server.metrics.indexing_task_metrics import _indexing_start_times
|
||||
from onyx.server.metrics.indexing_task_metrics import ConnectorInfo
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_COMPLETED
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_DURATION
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_STARTED
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state() -> Iterator[None]:
|
||||
"""Clear caches and state between tests.
|
||||
|
||||
Sets CURRENT_TENANT_ID_CONTEXTVAR to a realistic value so cache keys
|
||||
are never keyed on an empty string.
|
||||
"""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test_tenant")
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
yield
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _make_task(name: str) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
return task
|
||||
|
||||
|
||||
def _mock_db_lookup(
|
||||
source: str = "google_drive", name: str = "My Google Drive"
|
||||
) -> tuple:
|
||||
"""Return (session_patch, cc_pair_patch) context managers for DB mocking."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = name
|
||||
mock_cc_pair.connector.source.value = source
|
||||
|
||||
session_patch = patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
cc_pair_patch = patch(
|
||||
"onyx.db.connector_credential_pair.get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
)
|
||||
return session_patch, cc_pair_patch
|
||||
|
||||
|
||||
class TestIndexingTaskPrerun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docfetching(self) -> None:
|
||||
# Pre-populate cache to avoid DB lookup (tenant-scoped key)
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="My Google Drive"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "tenant-1"}
|
||||
|
||||
before = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
after = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docprocessing(self) -> None:
|
||||
_connector_cache[("test_tenant", 10)] = ConnectorInfo(
|
||||
source="slack", name="Slack Connector"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 10, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-2", task, kwargs)
|
||||
assert "task-2" in _indexing_start_times
|
||||
|
||||
def test_cache_hit_avoids_db_call(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="confluence", name="Engineering Confluence"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No DB patches needed — cache should be used
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_db_lookup_on_cache_miss(self) -> None:
|
||||
"""On first encounter of a cc_pair_id, does a DB lookup and caches."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "Notion Workspace"
|
||||
mock_cc_pair.connector.source.value = "notion"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve,
|
||||
):
|
||||
mock_resolve.return_value = ConnectorInfo(
|
||||
source="notion", name="Notion Workspace"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 77, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
mock_resolve.assert_called_once_with(77)
|
||||
|
||||
def test_missing_cc_pair_returns_unknown(self) -> None:
|
||||
"""When _resolve_connector can't find the cc_pair, uses 'unknown'."""
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 999, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_skips_when_cc_pair_id_missing(self) -> None:
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_db_error_does_not_crash(self) -> None:
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector",
|
||||
side_effect=Exception("DB down"),
|
||||
):
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
# Should not raise
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
|
||||
class TestIndexingTaskPostrun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
# Should not raise
|
||||
|
||||
def test_emits_completed_and_duration(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="Marketing Drive"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# Simulate prerun
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
before_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
after_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
after_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
assert after_completed == before_completed + 1
|
||||
assert after_duration > before_duration
|
||||
|
||||
def test_failure_outcome(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "FAILURE")
|
||||
|
||||
after = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
|
||||
def test_handles_postrun_without_prerun(self) -> None:
|
||||
"""Postrun for an indexing task without a matching prerun should not crash."""
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No prerun — should still emit completed counter, just skip duration
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
|
||||
class TestResolveConnector:
|
||||
def test_failed_lookup_not_cached(self) -> None:
|
||||
"""When DB lookup returns None, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(999)
|
||||
assert result.source == "unknown"
|
||||
# Should NOT be cached so subsequent calls can retry
|
||||
assert ("test-tenant", 999) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_exception_not_cached(self) -> None:
|
||||
"""When DB lookup raises, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"onyx.db.engine.sql_engine.get_session_with_current_tenant",
|
||||
side_effect=Exception("DB down"),
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(888)
|
||||
assert result.source == "unknown"
|
||||
assert ("test-tenant", 888) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_successful_lookup_is_cached(self) -> None:
|
||||
"""When DB lookup succeeds, result should be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "My Drive"
|
||||
mock_cc_pair.connector.source.value = "google_drive"
|
||||
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(777)
|
||||
assert result.source == "google_drive"
|
||||
assert result.name == "My Drive"
|
||||
assert ("test-tenant", 777) in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
69
backend/tests/unit/server/metrics/test_metrics_server.py
Normal file
69
backend/tests/unit/server/metrics/test_metrics_server.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for the Prometheus metrics server module."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.metrics_server import _DEFAULT_PORTS
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_server_state() -> Iterator[None]:
|
||||
"""Reset the global _server_started between tests."""
|
||||
import onyx.server.metrics.metrics_server as mod
|
||||
|
||||
mod._server_started = False
|
||||
yield
|
||||
mod._server_started = False
|
||||
|
||||
|
||||
class TestStartMetricsServer:
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_uses_default_port_for_known_worker(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == _DEFAULT_PORTS["monitoring"]
|
||||
mock_start.assert_called_once_with(_DEFAULT_PORTS["monitoring"])
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "9999"})
|
||||
def test_env_var_overrides_default(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == 9999
|
||||
mock_start.assert_called_once_with(9999)
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_ENABLED": "false"})
|
||||
def test_disabled_via_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_unknown_worker_type_no_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("unknown_worker")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_idempotent(self, mock_start: MagicMock) -> None:
|
||||
port1 = start_metrics_server("monitoring")
|
||||
port2 = start_metrics_server("monitoring")
|
||||
assert port1 == _DEFAULT_PORTS["monitoring"]
|
||||
assert port2 is None
|
||||
mock_start.assert_called_once()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_handles_os_error(self, mock_start: MagicMock) -> None:
|
||||
mock_start.side_effect = OSError("Address already in use")
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "not_a_number"})
|
||||
def test_invalid_port_env_var_returns_none(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
@@ -23,6 +23,12 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -46,8 +52,10 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
|
||||
@@ -23,6 +23,12 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -47,8 +53,10 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# we don't want nginx trying to do something clever with
|
||||
@@ -92,6 +100,8 @@ server {
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
# we don't want nginx trying to do something clever with
|
||||
# redirects, we set the Host: header above already.
|
||||
|
||||
@@ -23,6 +23,12 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -47,8 +53,10 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
@@ -106,6 +114,8 @@ server {
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
|
||||
@@ -28,6 +28,12 @@ data:
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server.conf: |
|
||||
server {
|
||||
listen 1024;
|
||||
@@ -65,6 +71,8 @@ data:
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
# timeout settings
|
||||
|
||||
@@ -10,7 +10,7 @@ data:
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
HOST="${POSTGRES_HOST:-localhost}"
|
||||
HOST="${PGINTO_HOST:-${POSTGRES_HOST:-localhost}}"
|
||||
PORT="${POSTGRES_PORT:-5432}"
|
||||
USER="${POSTGRES_USER:-postgres}"
|
||||
DB="${POSTGRES_DB:-postgres}"
|
||||
|
||||
@@ -103,7 +103,7 @@ opensearch:
|
||||
- name: OPENSEARCH_INITIAL_ADMIN_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-opensearch # Must match auth.opensearch.secretName.
|
||||
name: onyx-opensearch # Must match auth.opensearch.secretName or auth.opensearch.existingSecret if defined.
|
||||
key: opensearch_admin_password # Must match auth.opensearch.secretKeys value.
|
||||
|
||||
resources:
|
||||
@@ -282,7 +282,7 @@ nginx:
|
||||
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
|
||||
# Workaround: Helm upgrade will restart if the following annotation value changes.
|
||||
podAnnotations:
|
||||
onyx.app/nginx-config-version: "1"
|
||||
onyx.app/nginx-config-version: "2"
|
||||
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
|
||||
@@ -83,6 +83,14 @@
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
|
||||
import {
|
||||
Disabled,
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
@@ -32,9 +36,6 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
*/
|
||||
size?: ContainerSizeVariants;
|
||||
|
||||
/** HTML button type. When provided, Container renders a `<button>` element. */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
@@ -43,6 +44,9 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -59,6 +63,7 @@ function Button({
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
responsiveHideText = false,
|
||||
disabled,
|
||||
...interactiveProps
|
||||
}: ButtonProps) {
|
||||
const isLarge = size === "lg";
|
||||
@@ -76,7 +81,7 @@ function Button({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateless {...interactiveProps}>
|
||||
<Interactive.Stateless type={type} {...interactiveProps}>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
border={interactiveProps.prominence === "secondary"}
|
||||
@@ -102,9 +107,7 @@ function Button({
|
||||
</Interactive.Stateless>
|
||||
);
|
||||
|
||||
if (!tooltip) return button;
|
||||
|
||||
return (
|
||||
const result = tooltip ? (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
@@ -117,7 +120,15 @@ function Button({
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
) : (
|
||||
button
|
||||
);
|
||||
|
||||
if (disabled != null) {
|
||||
return <Disabled disabled={disabled}>{result}</Disabled>;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export { Button, type ButtonProps };
|
||||
|
||||
8
web/lib/opal/src/components/buttons/chevron.css
Normal file
8
web/lib/opal/src/components/buttons/chevron.css
Normal file
@@ -0,0 +1,8 @@
|
||||
.opal-button-chevron {
|
||||
transition: rotate 200ms ease;
|
||||
}
|
||||
|
||||
.interactive[data-interaction="hover"] .opal-button-chevron,
|
||||
.interactive[data-interaction="active"] .opal-button-chevron {
|
||||
rotate: -180deg;
|
||||
}
|
||||
22
web/lib/opal/src/components/buttons/chevron.tsx
Normal file
22
web/lib/opal/src/components/buttons/chevron.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import "@opal/components/buttons/chevron.css";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import { SvgChevronDownSmall } from "@opal/icons";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
/**
|
||||
* Chevron icon that rotates 180° when its parent `.interactive` enters
|
||||
* hover / active state. Shared by OpenButton, FilterButton, and any
|
||||
* future button that needs an animated dropdown indicator.
|
||||
*
|
||||
* Stable component identity — never causes React to remount the SVG.
|
||||
*/
|
||||
function ChevronIcon({ className, ...props }: IconProps) {
|
||||
return (
|
||||
<SvgChevronDownSmall
|
||||
className={cn(className, "opal-button-chevron")}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export { ChevronIcon };
|
||||
@@ -0,0 +1,107 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { FilterButton } from "@opal/components";
|
||||
import { Disabled as DisabledProvider } from "@opal/core";
|
||||
import { SvgUser, SvgActions, SvgTag } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
|
||||
const meta: Meta<typeof FilterButton> = {
|
||||
title: "opal/components/FilterButton",
|
||||
component: FilterButton,
|
||||
tags: ["autodocs"],
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<TooltipPrimitive.Provider>
|
||||
<Story />
|
||||
</TooltipPrimitive.Provider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof FilterButton>;
|
||||
|
||||
export const Empty: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
children: "Everyone",
|
||||
},
|
||||
};
|
||||
|
||||
export const Active: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
active: true,
|
||||
children: "By alice@example.com",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
};
|
||||
|
||||
export const Open: Story = {
|
||||
args: {
|
||||
icon: SvgActions,
|
||||
interaction: "hover",
|
||||
children: "All Actions",
|
||||
},
|
||||
};
|
||||
|
||||
export const ActiveOpen: Story = {
|
||||
args: {
|
||||
icon: SvgActions,
|
||||
active: true,
|
||||
interaction: "hover",
|
||||
children: "2 selected",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
icon: SvgTag,
|
||||
children: "All Tags",
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<DisabledProvider disabled>
|
||||
<Story />
|
||||
</DisabledProvider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export const DisabledActive: Story = {
|
||||
args: {
|
||||
icon: SvgTag,
|
||||
active: true,
|
||||
children: "2 tags",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<DisabledProvider disabled>
|
||||
<Story />
|
||||
</DisabledProvider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export const StateComparison: Story = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: 12, alignItems: "center" }}>
|
||||
<FilterButton icon={SvgUser} onClear={() => undefined}>
|
||||
Everyone
|
||||
</FilterButton>
|
||||
<FilterButton icon={SvgUser} active onClear={() => console.log("clear")}>
|
||||
By alice@example.com
|
||||
</FilterButton>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
export const WithTooltip: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
children: "Everyone",
|
||||
tooltip: "Filter by creator",
|
||||
tooltipSide: "bottom",
|
||||
},
|
||||
};
|
||||
70
web/lib/opal/src/components/buttons/filter-button/README.md
Normal file
70
web/lib/opal/src/components/buttons/filter-button/README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# FilterButton
|
||||
|
||||
**Import:** `import { FilterButton, type FilterButtonProps } from "@opal/components";`
|
||||
|
||||
A stateful filter trigger with a built-in chevron (when empty) and a clear button (when selected). Hardcodes `variant="select-filter"` and delegates to `Interactive.Stateful`, adding automatic open-state detection from Radix `data-state`. Designed to sit inside a `Popover.Trigger` for filter dropdowns.
|
||||
|
||||
## Relationship to OpenButton
|
||||
|
||||
FilterButton shares a similar call stack to `OpenButton`:
|
||||
|
||||
```
|
||||
Interactive.Stateful → Interactive.Container → content row (icon + label + trailing indicator)
|
||||
```
|
||||
|
||||
FilterButton is a **narrower, filter-specific** variant:
|
||||
|
||||
- It hardcodes `variant="select-filter"` (OpenButton uses `"select-heavy"`)
|
||||
- It exposes `active?: boolean` instead of the raw `state` prop (maps to `"selected"` / `"empty"` internally)
|
||||
- When active, the chevron is hidden via `visibility` and an absolutely-positioned clear `Button` with `prominence="tertiary"` overlays it — placed as a sibling outside the `<button>` to avoid nesting buttons
|
||||
- It uses the shared `ChevronIcon` from `buttons/chevron` (same as OpenButton)
|
||||
- It does not support `foldable`, `size`, or `width` — it is always `"lg"`
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
div.relative <- bounding wrapper
|
||||
Interactive.Stateful <- variant="select-filter", interaction, state
|
||||
└─ Interactive.Container (button) <- height="lg", default rounding/padding
|
||||
└─ div.interactive-foreground
|
||||
├─ div > Icon (interactive-foreground-icon)
|
||||
├─ <span> label text
|
||||
└─ ChevronIcon (when empty)
|
||||
OR spacer div (when selected — reserves chevron space)
|
||||
div.absolute <- clear Button overlay (when selected)
|
||||
└─ Button (SvgX, size="2xs", prominence="tertiary")
|
||||
```
|
||||
|
||||
- **Open-state detection** reads `data-state="open"` injected by Radix triggers (e.g. `Popover.Trigger`), falling back to the explicit `interaction` prop.
|
||||
- **Chevron rotation** uses the shared `ChevronIcon` component and `buttons/chevron.css`, which rotates 180deg when `data-interaction="hover"`.
|
||||
- **Clear button** is absolutely positioned outside the `<button>` element tree to avoid invalid nested `<button>` elements. An invisible spacer inside the button reserves the same space so layout doesn't shift between states.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `icon` | `IconFunctionComponent` | **required** | Left icon component |
|
||||
| `children` | `string` | **required** | Label text between icon and trailing indicator |
|
||||
| `active` | `boolean` | `false` | Whether the filter has an active selection |
|
||||
| `onClear` | `() => void` | **required** | Called when the clear (X) button is clicked |
|
||||
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"`. |
|
||||
| `tooltip` | `string` | — | Tooltip text shown on hover |
|
||||
| `tooltipSide` | `TooltipSide` | `"top"` | Which side the tooltip appears on |
|
||||
|
||||
## Usage
|
||||
|
||||
```tsx
|
||||
import { FilterButton } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
// Inside a Popover (auto-detects open state)
|
||||
<Popover.Trigger asChild>
|
||||
<FilterButton
|
||||
icon={SvgUser}
|
||||
active={hasSelection}
|
||||
onClear={() => clearSelection()}
|
||||
>
|
||||
{hasSelection ? selectionLabel : "Everyone"}
|
||||
</FilterButton>
|
||||
</Popover.Trigger>
|
||||
```
|
||||
120
web/lib/opal/src/components/buttons/filter-button/components.tsx
Normal file
120
web/lib/opal/src/components/buttons/filter-button/components.tsx
Normal file
@@ -0,0 +1,120 @@
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveStatefulInteraction,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
import { ChevronIcon } from "@opal/components/buttons/chevron";
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FilterButtonProps
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
|
||||
/** Left icon — always visible. */
|
||||
icon: IconFunctionComponent;
|
||||
|
||||
/** Label text between icon and trailing indicator. */
|
||||
children: string;
|
||||
|
||||
/** Whether the filter has an active selection. @default false */
|
||||
active?: boolean;
|
||||
|
||||
/** Called when the clear (X) button is clicked in active state. */
|
||||
onClear: () => void;
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FilterButton
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function FilterButton({
|
||||
icon: Icon,
|
||||
children,
|
||||
onClear,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
active = false,
|
||||
interaction,
|
||||
...statefulProps
|
||||
}: FilterButtonProps) {
|
||||
// Derive open state: explicit prop > Radix data-state (injected via Slot chain)
|
||||
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
|
||||
| string
|
||||
| undefined;
|
||||
const resolvedInteraction: InteractiveStatefulInteraction =
|
||||
interaction ?? (dataState === "open" ? "hover" : "rest");
|
||||
|
||||
const button = (
|
||||
<div className="relative">
|
||||
<Interactive.Stateful
|
||||
{...statefulProps}
|
||||
variant="select-filter"
|
||||
interaction={resolvedInteraction}
|
||||
state={active ? "selected" : "empty"}
|
||||
>
|
||||
<Interactive.Container type="button">
|
||||
<div className="interactive-foreground flex flex-row items-center gap-1">
|
||||
{iconWrapper(Icon, "lg", true)}
|
||||
<span className="whitespace-nowrap font-main-ui-action">
|
||||
{children}
|
||||
</span>
|
||||
<div style={{ visibility: active ? "hidden" : "visible" }}>
|
||||
{iconWrapper(ChevronIcon, "lg", true)}
|
||||
</div>
|
||||
</div>
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateful>
|
||||
|
||||
{active && (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
{/* Force hover state so the X stays visually prominent against
|
||||
the inverted selected background — without this it renders
|
||||
dimmed and looks disabled. */}
|
||||
<Button
|
||||
icon={SvgX}
|
||||
size="2xs"
|
||||
prominence="tertiary"
|
||||
tooltip="Clear filter"
|
||||
interaction="hover"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onClear();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
if (!tooltip) return button;
|
||||
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
|
||||
export { FilterButton, type FilterButtonProps };
|
||||
@@ -1,8 +1,5 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveStatefulState,
|
||||
type InteractiveStatefulInteraction,
|
||||
type InteractiveStatefulProps,
|
||||
InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core";
|
||||
@@ -22,40 +19,26 @@ type ContentPassthroughProps = DistributiveOmit<
|
||||
"paddingVariant" | "widthVariant" | "ref" | "withInteractive"
|
||||
>;
|
||||
|
||||
type LineItemButtonOwnProps = {
|
||||
type LineItemButtonOwnProps = Pick<
|
||||
InteractiveStatefulProps,
|
||||
| "state"
|
||||
| "interaction"
|
||||
| "onClick"
|
||||
| "href"
|
||||
| "target"
|
||||
| "group"
|
||||
| "ref"
|
||||
| "type"
|
||||
> & {
|
||||
/** Interactive select variant. @default "select-light" */
|
||||
selectVariant?: "select-light" | "select-heavy";
|
||||
|
||||
/** Value state. @default "empty" */
|
||||
state?: InteractiveStatefulState;
|
||||
|
||||
/** JS-controllable interaction state override. @default "rest" */
|
||||
interaction?: InteractiveStatefulInteraction;
|
||||
|
||||
/** Click handler. */
|
||||
onClick?: InteractiveStatefulProps["onClick"];
|
||||
|
||||
/** When provided, renders an anchor instead of a div. */
|
||||
href?: string;
|
||||
|
||||
/** Anchor target (e.g. "_blank"). */
|
||||
target?: string;
|
||||
|
||||
/** Interactive group key. */
|
||||
group?: string;
|
||||
|
||||
/** Forwarded ref. */
|
||||
ref?: React.Ref<HTMLElement>;
|
||||
|
||||
/** Corner rounding preset (height is always content-driven). @default "default" */
|
||||
roundingVariant?: InteractiveContainerRoundingVariant;
|
||||
|
||||
/** Container width. @default "full" */
|
||||
width?: ExtremaSizeVariants;
|
||||
|
||||
/** HTML button type. @default "button" */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
@@ -79,11 +62,11 @@ function LineItemButton({
|
||||
target,
|
||||
group,
|
||||
ref,
|
||||
type = "button",
|
||||
|
||||
// Sizing
|
||||
roundingVariant = "default",
|
||||
width = "full",
|
||||
type = "button",
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
|
||||
|
||||
@@ -40,13 +40,6 @@ export const Open: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
disabled: true,
|
||||
children: "Disabled",
|
||||
},
|
||||
};
|
||||
|
||||
export const Foldable: Story = {
|
||||
args: {
|
||||
foldable: true,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import "@opal/components/buttons/open-button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
@@ -9,24 +7,11 @@ import {
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { InteractiveContainerRoundingVariant } from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent, IconProps } from "@opal/types";
|
||||
import { SvgChevronDownSmall } from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { cn } from "@opal/utils";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Chevron (stable identity — never causes React to remount the SVG)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ChevronIcon({ className, ...props }: IconProps) {
|
||||
return (
|
||||
<SvgChevronDownSmall
|
||||
className={cn(className, "opal-open-button-chevron")}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
import { ChevronIcon } from "@opal/components/buttons/chevron";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
.opal-open-button-chevron {
|
||||
transition: rotate 200ms ease;
|
||||
}
|
||||
|
||||
.interactive[data-interaction="hover"] .opal-open-button-chevron,
|
||||
.interactive[data-interaction="active"] .opal-open-button-chevron {
|
||||
rotate: -180deg;
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
import "@opal/components/buttons/select-button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
@@ -50,9 +49,6 @@ type SelectButtonProps = InteractiveStatefulProps &
|
||||
*/
|
||||
size?: ContainerSizeVariants;
|
||||
|
||||
/** HTML button type. Container renders a `<button>` element. */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
|
||||
/* Shared types */
|
||||
export type TooltipSide = "top" | "bottom" | "left" | "right";
|
||||
|
||||
@@ -19,6 +21,12 @@ export {
|
||||
type OpenButtonProps,
|
||||
} from "@opal/components/buttons/open-button/components";
|
||||
|
||||
/* FilterButton */
|
||||
export {
|
||||
FilterButton,
|
||||
type FilterButtonProps,
|
||||
} from "@opal/components/buttons/filter-button/components";
|
||||
|
||||
/* LineItemButton */
|
||||
export {
|
||||
LineItemButton,
|
||||
|
||||
@@ -32,7 +32,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
|
||||
// User-defined columns only (exclude internal qualifier/actions)
|
||||
const dataColumns = table
|
||||
.getAllLeafColumns()
|
||||
.filter((col) => !col.id.startsWith("__") && col.id !== "qualifier");
|
||||
.filter(
|
||||
(col) =>
|
||||
!col.id.startsWith("__") &&
|
||||
col.id !== "qualifier" &&
|
||||
typeof col.columnDef.header === "string" &&
|
||||
col.columnDef.header.trim() !== ""
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
|
||||
@@ -145,6 +145,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
draggable,
|
||||
footer,
|
||||
size = "lg",
|
||||
@@ -221,6 +223,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize: effectivePageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
|
||||
@@ -103,6 +103,10 @@ interface UseDataTableOptions<TData extends RowData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. @default {} */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. @default {} */
|
||||
initialRowSelection?: RowSelectionState;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode (filtered to selected rows). @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Called whenever the set of selected row IDs changes. */
|
||||
onSelectionChange?: (selectedIds: string[]) => void;
|
||||
/** Search term for global text filtering. Rows are filtered to those containing
|
||||
@@ -195,6 +199,8 @@ export default function useDataTable<TData extends RowData>(
|
||||
columnResizeMode = "onChange",
|
||||
initialSorting = [],
|
||||
initialColumnVisibility = {},
|
||||
initialRowSelection = {},
|
||||
initialViewSelected = false,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
@@ -206,7 +212,8 @@ export default function useDataTable<TData extends RowData>(
|
||||
|
||||
// ---- internal state -----------------------------------------------------
|
||||
const [sorting, setSorting] = useState<SortingState>(initialSorting);
|
||||
const [rowSelection, setRowSelection] = useState<RowSelectionState>({});
|
||||
const [rowSelection, setRowSelection] =
|
||||
useState<RowSelectionState>(initialRowSelection);
|
||||
const [columnSizing, setColumnSizing] = useState<ColumnSizingState>({});
|
||||
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
|
||||
initialColumnVisibility
|
||||
@@ -216,8 +223,12 @@ export default function useDataTable<TData extends RowData>(
|
||||
pageSize: pageSizeOption,
|
||||
});
|
||||
/** Combined global filter: view-mode (selected IDs) + text search. */
|
||||
const initialSelectedIds =
|
||||
initialViewSelected && Object.keys(initialRowSelection).length > 0
|
||||
? new Set(Object.keys(initialRowSelection))
|
||||
: null;
|
||||
const [globalFilter, setGlobalFilter] = useState<GlobalFilterValue>({
|
||||
selectedIds: null,
|
||||
selectedIds: initialSelectedIds,
|
||||
searchTerm: "",
|
||||
});
|
||||
|
||||
@@ -384,6 +395,31 @@ export default function useDataTable<TData extends RowData>(
|
||||
: data.length;
|
||||
const isPaginated = isFinite(pagination.pageSize);
|
||||
|
||||
// ---- keep view-mode filter in sync with selection ----------------------
|
||||
// When in view-selected mode, deselecting a row should remove it from
|
||||
// the visible set so it disappears immediately.
|
||||
useEffect(() => {
|
||||
if (isServerSide) return;
|
||||
if (globalFilter.selectedIds == null) return;
|
||||
|
||||
const currentIds = new Set(Object.keys(rowSelection));
|
||||
// Remove any ID from the filter that is no longer selected
|
||||
let changed = false;
|
||||
const next = new Set<string>();
|
||||
globalFilter.selectedIds.forEach((id) => {
|
||||
if (currentIds.has(id)) {
|
||||
next.add(id);
|
||||
} else {
|
||||
changed = true;
|
||||
}
|
||||
});
|
||||
if (changed) {
|
||||
setGlobalFilter((prev) => ({ ...prev, selectedIds: next }));
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps -- only react to
|
||||
// selection changes while in view mode
|
||||
}, [rowSelection, isServerSide]);
|
||||
|
||||
// ---- selection change callback ------------------------------------------
|
||||
const isFirstRenderRef = useRef(true);
|
||||
const onSelectionChangeRef = useRef(onSelectionChange);
|
||||
@@ -392,6 +428,10 @@ export default function useDataTable<TData extends RowData>(
|
||||
useEffect(() => {
|
||||
if (isFirstRenderRef.current) {
|
||||
isFirstRenderRef.current = false;
|
||||
// Still fire the callback on first render if there's an initial selection
|
||||
if (selectedRowIds.length > 0) {
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
}
|
||||
return;
|
||||
}
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
|
||||
@@ -146,6 +146,10 @@ export interface DataTableProps<TData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. */
|
||||
initialRowSelection?: Record<string, boolean>;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode. @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Enable drag-and-drop row reordering. */
|
||||
draggable?: DataTableDraggableConfig;
|
||||
/** Footer configuration. */
|
||||
|
||||
@@ -88,9 +88,12 @@ function HoverableRoot({
|
||||
ref,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
onFocusCapture: consumerFocusCapture,
|
||||
onBlurCapture: consumerBlurCapture,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
const [focused, setFocused] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
@@ -108,16 +111,40 @@ function HoverableRoot({
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const onFocusCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
setFocused(true);
|
||||
consumerFocusCapture?.(e);
|
||||
},
|
||||
[consumerFocusCapture]
|
||||
);
|
||||
|
||||
const onBlurCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
if (
|
||||
!(e.relatedTarget instanceof Node) ||
|
||||
!e.currentTarget.contains(e.relatedTarget)
|
||||
) {
|
||||
setFocused(false);
|
||||
}
|
||||
consumerBlurCapture?.(e);
|
||||
},
|
||||
[consumerBlurCapture]
|
||||
);
|
||||
|
||||
const active = hovered || focused;
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<GroupContext.Provider value={hovered}>
|
||||
<GroupContext.Provider value={active}>
|
||||
<div
|
||||
{...props}
|
||||
ref={ref}
|
||||
className={cn(widthVariants[widthVariant])}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseLeave={onMouseLeave}
|
||||
onFocusCapture={onFocusCapture}
|
||||
onBlurCapture={onBlurCapture}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user