mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 09:32:45 +00:00
Compare commits
2 Commits
main
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
47f4acb51e | ||
|
|
113cdb4a5d |
@@ -6,4 +6,3 @@
|
||||
|
||||
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
|
||||
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
|
||||
7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339)
|
||||
|
||||
@@ -7,15 +7,6 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "backend/**"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
- ".github/workflows/pr-external-dependency-unit-tests.yml"
|
||||
- ".github/actions/setup-python-and-install-dependencies/**"
|
||||
- ".github/actions/setup-playwright/**"
|
||||
- "deployment/docker_compose/docker-compose.yml"
|
||||
- "deployment/docker_compose/docker-compose.dev.yml"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
@@ -7,13 +7,6 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "backend/**"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
- ".github/workflows/pr-python-connector-tests.yml"
|
||||
- ".github/actions/setup-python-and-install-dependencies/**"
|
||||
- ".github/actions/setup-playwright/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -117,8 +117,7 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console",
|
||||
"justMyCode": false
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -269,8 +268,7 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console",
|
||||
"justMyCode": false
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
@@ -357,8 +355,7 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console",
|
||||
"justMyCode": false
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
@@ -416,8 +413,7 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
"""group_permissions_phase1
|
||||
|
||||
Revision ID: 25a5501dc766
|
||||
Revises: b728689f45b1
|
||||
Create Date: 2026-03-23 11:41:25.557442
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "25a5501dc766"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Add account_type column to user table (nullable for now).
|
||||
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"account_type",
|
||||
sa.Enum(AccountType, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. Add is_default column to user_group table
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Create permission_grant table
|
||||
op.create_table(
|
||||
"permission_grant",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"permission",
|
||||
sa.Enum(Permission, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"grant_source",
|
||||
sa.Enum(GrantSource, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["group_id"],
|
||||
["user_group.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["granted_by"],
|
||||
["user.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Index on user__user_group(user_id) — existing composite PK
|
||||
# has user_group_id as leading column; user-filtered queries need this
|
||||
op.create_index(
|
||||
"ix_user__user_group_user_id",
|
||||
"user__user_group",
|
||||
["user_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
|
||||
op.drop_table("permission_grant")
|
||||
op.drop_column("user_group", "is_default")
|
||||
op.drop_column("user", "account_type")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add preferred_response_id and model_display_name to chat_message
|
||||
|
||||
Revision ID: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-22
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3f8b2c1d4e5"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"preferred_response_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "model_display_name")
|
||||
op.drop_column("chat_message", "preferred_response_id")
|
||||
@@ -25,13 +25,10 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Maximum tenants to provision in a single task run.
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -88,26 +85,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
for i in range(batch_size):
|
||||
task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
|
||||
try:
|
||||
if pre_provision_tenant():
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
# just provision one tenant each time we run this ... increase if needed.
|
||||
if tenants_to_provision > 0:
|
||||
pre_provision_tenant()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
@@ -121,13 +101,11 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> bool:
|
||||
def pre_provision_tenant() -> None:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
so it's ready to be assigned to a user immediately.
|
||||
|
||||
Returns True if a tenant was successfully provisioned, False otherwise.
|
||||
"""
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
@@ -140,10 +118,10 @@ def pre_provision_tenant() -> bool:
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
if not lock_provision.acquire(blocking=False):
|
||||
task_logger.warning(
|
||||
"Skipping pre_provision_tenant — could not acquire provision lock"
|
||||
task_logger.debug(
|
||||
"Skipping pre_provision_tenant task because it is already running"
|
||||
)
|
||||
return False
|
||||
return
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
@@ -183,7 +161,6 @@ def pre_provision_tenant() -> bool:
|
||||
db_session.add(new_tenant)
|
||||
db_session.commit()
|
||||
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
|
||||
return True
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
task_logger.error(
|
||||
@@ -207,7 +184,6 @@ def pre_provision_tenant() -> bool:
|
||||
asyncio.run(rollback_tenant_provisioning(tenant_id))
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
lock_provision.release()
|
||||
|
||||
@@ -115,14 +115,8 @@ def fetch_user_group_token_rate_limits_for_user(
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = (
|
||||
select(TokenRateLimit)
|
||||
.join(
|
||||
TokenRateLimit__UserGroup,
|
||||
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
|
||||
)
|
||||
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
|
||||
@@ -800,33 +800,6 @@ def update_user_group(
|
||||
return db_user_group
|
||||
|
||||
|
||||
def rename_user_group(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
new_name: str,
|
||||
) -> UserGroup:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
if db_user_group is None:
|
||||
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
db_user_group.name = new_name
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
# CC pair documents in Vespa contain the group name, so we need to
|
||||
# trigger a sync to update them with the new name.
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
|
||||
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
|
||||
@@ -56,7 +56,7 @@ def _run_single_search(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona_search_info=None,
|
||||
persona=None, # No persona for direct search
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.persona import update_persona_access
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
@@ -12,16 +11,13 @@ from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -31,9 +27,6 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -94,32 +87,6 @@ def create_user_group(
|
||||
return UserGroup.from_model(db_user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
db_session=db_session,
|
||||
user_group_id=rename_request.id,
|
||||
new_name=rename_request.name,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"User group with name '{rename_request.name}' already exists.",
|
||||
)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "not found" in msg.lower():
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, msg)
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, msg)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}")
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
@@ -194,38 +161,3 @@ def delete_user_group(
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}/agents")
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
if user_group_id not in current_group_ids:
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=current_group_ids + [user_group_id],
|
||||
)
|
||||
|
||||
for agent_id in request.removed_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=[gid for gid in current_group_ids if gid != user_group_id],
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -104,16 +104,6 @@ class AddUsersToUserGroupRequest(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
class UserGroupRename(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class SetCuratorRequest(BaseModel):
|
||||
user_id: UUID
|
||||
is_curator: bool
|
||||
|
||||
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
|
||||
@@ -13,14 +13,6 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -42,8 +34,6 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -58,36 +48,6 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -116,7 +76,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docfetching")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -14,14 +14,6 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -43,8 +35,6 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -59,36 +49,6 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -122,7 +82,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docprocessing")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -131,12 +90,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
|
||||
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
|
||||
# effectively a no-op in normal operation. It remains as a safety net in case
|
||||
# the pool type is ever changed to prefork. Prometheus metrics are safe in
|
||||
# thread-pool mode since all threads share the same process memory and can
|
||||
# update the same Counter/Gauge/Histogram objects directly.
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -54,14 +54,8 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
# Set by on_worker_init so on_worker_ready knows whether to start the server.
|
||||
_prometheus_collectors_ok: bool = False
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
global _prometheus_collectors_ok
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
@@ -71,8 +65,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
@@ -80,37 +72,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
def _setup_prometheus_collectors(sender: Any) -> bool:
|
||||
"""Register Prometheus collectors that need Redis/DB access.
|
||||
|
||||
Passes the Celery app so the queue depth collector can obtain a fresh
|
||||
broker Redis client on each scrape (rather than holding a stale reference).
|
||||
|
||||
Returns True if registration succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
from onyx.server.metrics.indexing_pipeline_setup import (
|
||||
setup_indexing_pipeline_metrics,
|
||||
)
|
||||
|
||||
setup_indexing_pipeline_metrics(sender.app)
|
||||
logger.info("Prometheus indexing pipeline collectors registered")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to register Prometheus indexing pipeline collectors")
|
||||
return False
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
if _prometheus_collectors_ok:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
start_metrics_server("monitoring")
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping Prometheus metrics server — collector registration failed"
|
||||
)
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamingError
|
||||
| CreateChatSessionID
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
|
||||
from onyx.chat.compression import compress_chat_history
|
||||
from onyx.chat.compression import find_summary_for_branch
|
||||
from onyx.chat.compression import get_compression_params
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import EmptyLLMResponseError
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -59,7 +62,8 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -69,33 +73,29 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingPayload
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
@@ -433,28 +433,6 @@ def determine_search_params(
|
||||
)
|
||||
|
||||
|
||||
def _resolve_query_processing_hook_result(
    hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed,
    message_text: str,
) -> str:
    """Apply the Query Processing hook result to the message text.

    Returns the (possibly rewritten) message text, or raises OnyxError with
    QUERY_REJECTED if the hook signals rejection (query is null or empty).
    HookSkipped and HookSoftFailed are pass-throughs — the original text is
    returned unchanged.
    """
    # Hook not configured / soft-failed: keep the user's original text.
    if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
        return message_text

    rewritten = hook_result.query.strip() if hook_result.query else ""
    if not rewritten:
        # An empty/absent query from the hook means the query was rejected.
        raise OnyxError(
            OnyxErrorCode.QUERY_REJECTED,
            hook_result.rejection_message
            or "The hook extension for query processing did not return a valid query. No rejection reason was provided.",
        )
    return rewritten
|
||||
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
@@ -505,24 +483,16 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
|
||||
message_text = new_msg_req.message
|
||||
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session.id)
|
||||
)
|
||||
@@ -614,28 +584,6 @@ def handle_stream_message_objects(
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
# New message — run the Query Processing hook before saving to DB.
|
||||
# Skipped on regeneration: the message already exists and was accepted previously.
|
||||
# Skip the hook for empty/whitespace-only messages — no meaningful query
|
||||
# to process, and SendMessageRequest.message has no min_length guard.
|
||||
if message_text.strip():
|
||||
hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=QueryProcessingPayload(
|
||||
query=message_text,
|
||||
# Pass None for anonymous users or authenticated users without an email
|
||||
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
|
||||
# so None is accepted and serialised as null in both cases.
|
||||
user_email=None if user.is_anonymous else user.email,
|
||||
chat_session_id=str(chat_session.id),
|
||||
).model_dump(),
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
message_text = _resolve_query_processing_hook_result(
|
||||
hook_result, message_text
|
||||
)
|
||||
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
@@ -975,17 +923,6 @@ def handle_stream_message_objects(
|
||||
state_container=state_container,
|
||||
)
|
||||
|
||||
except OnyxError as e:
|
||||
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
|
||||
log_onyx_error(e)
|
||||
yield StreamingError(
|
||||
error=e.detail,
|
||||
error_code=e.error_code.code,
|
||||
is_retryable=e.status_code >= 500,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -1069,6 +1006,568 @@ def handle_stream_message_objects(
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def _build_model_display_name(override: LLMOverride) -> str:
|
||||
"""Build a human-readable display name from an LLM override."""
|
||||
if override.display_name:
|
||||
return override.display_name
|
||||
if override.model_version:
|
||||
return override.model_version
|
||||
if override.model_provider:
|
||||
return override.model_provider
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Sentinel placed on the merged queue when a model thread finishes.
# A fresh object() so it can only ever match via identity (`is`) checks.
_MODEL_DONE = object()
|
||||
|
||||
|
||||
class _ModelIndexEmitter(Emitter):
    """Emitter that tags packets with model_index and forwards directly to a shared queue.

    Unlike the standard Emitter (which accumulates in a local bus), this puts
    packets into the shared merged_queue in real-time as they're emitted. This
    enables true parallel streaming — packets from multiple models interleave
    on the wire instead of arriving in bursts after each model completes.
    """

    def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
        # The base-class bus is unused here; it exists only for interface compat.
        super().__init__(queue.Queue())
        self._model_idx = model_idx
        self._merged_queue = merged_queue

    def emit(self, packet: Packet) -> None:
        # Re-stamp the placement with this emitter's model index, defaulting
        # positional fields when the packet carries no placement at all.
        source = packet.placement
        stamped = Placement(
            turn_index=source.turn_index if source else 0,
            tab_index=source.tab_index if source else 0,
            sub_turn_index=source.sub_turn_index if source else None,
            model_index=self._model_idx,
        )
        self._merged_queue.put(
            (self._model_idx, Packet(placement=stamped, obj=packet.obj))
        )
|
||||
|
||||
|
||||
def run_multi_model_stream(
    new_msg_req: SendMessageRequest,
    user: User,
    db_session: Session,
    llm_overrides: list[LLMOverride],
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
    # TODO: The setup logic below (session resolution through tool construction)
    # is duplicated from handle_stream_message_objects. Extract into a shared
    # _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
    # both paths call the same setup code. Tracked as follow-up refactor.
    """Run 2-3 LLMs in parallel and yield their packets tagged with model_index.

    Resource management:
    - Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
    - The caller's db_session is used only for setup (before threads launch) and
      completion callbacks (after threads finish)
    - ThreadPoolExecutor is bounded to len(overrides) workers
    - All threads are joined in the finally block regardless of success/failure
    - Queue-based merging avoids busy-waiting
    """
    n_models = len(llm_overrides)
    if n_models < 2 or n_models > 3:
        raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
    if new_msg_req.deep_research:
        raise ValueError("Multi-model is not supported with deep research")

    tenant_id = get_current_tenant_id()
    # Initialized to None so the outer finally can safely test whether the
    # processing-status flag was ever set before trying to clear it.
    cache: CacheBackend | None = None
    chat_session: ChatSession | None = None

    user_id = user.id
    if user.is_anonymous:
        llm_user_identifier = "anonymous_user"
    else:
        llm_user_identifier = user.email or str(user_id)

    try:
        # ── Session setup (same as single-model path) ──────────────────
        if not new_msg_req.chat_session_id:
            if not new_msg_req.chat_session_info:
                raise RuntimeError(
                    "Must specify a chat session id or chat session info"
                )
            chat_session = create_chat_session_from_request(
                chat_session_request=new_msg_req.chat_session_info,
                user_id=user_id,
                db_session=db_session,
            )
            # Tell the client the id of the freshly-created session immediately.
            yield CreateChatSessionID(chat_session_id=chat_session.id)
        else:
            chat_session = get_chat_session_by_id(
                chat_session_id=new_msg_req.chat_session_id,
                user_id=user_id,
                db_session=db_session,
            )

        persona = chat_session.persona
        message_text = new_msg_req.message

        # ── Build N LLM instances and validate costs ───────────────────
        llms: list[LLM] = []
        model_display_names: list[str] = []
        for override in llm_overrides:
            llm = get_llm_for_persona(
                persona=persona,
                user=user,
                llm_override=override,
                additional_headers=litellm_additional_headers,
            )
            check_llm_cost_limit_for_provider(
                db_session=db_session,
                tenant_id=tenant_id,
                llm_provider_api_key=llm.config.api_key,
            )
            llms.append(llm)
            model_display_names.append(_build_model_display_name(override))

        # Use first LLM for token counting (context window is checked per-model
        # but token counting is model-agnostic enough for setup purposes)
        token_counter = get_llm_token_counter(llms[0])

        verify_user_files(
            user_files=new_msg_req.file_descriptors,
            user_id=user_id,
            db_session=db_session,
            project_id=chat_session.project_id,
        )

        # ── Chat history chain (shared across all models) ──────────────
        chat_history = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )

        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )

        # Resolve the parent message the new user message hangs off of.
        if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
            parent_message = chat_history[-1] if chat_history else root_message
        elif (
            new_msg_req.parent_message_id is None
            or new_msg_req.parent_message_id == root_message.id
        ):
            parent_message = root_message
            chat_history = []
        else:
            # Walk backwards to find the parent on the mainline and truncate
            # history at that point (everything after it belongs to a branch).
            parent_message = None
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i].id == new_msg_req.parent_message_id:
                    parent_message = chat_history[i]
                    chat_history = chat_history[: i + 1]
                    break

        if parent_message is None:
            raise ValueError(
                "The new message sent is not on the latest mainline of messages"
            )

        if parent_message.message_type == MessageType.USER:
            # Regeneration: the user message already exists in the chain.
            user_message = parent_message
        else:
            user_message = create_new_chat_message(
                chat_session_id=chat_session.id,
                parent_message=parent_message,
                message=message_text,
                token_count=token_counter(message_text),
                message_type=MessageType.USER,
                files=new_msg_req.file_descriptors,
                db_session=db_session,
                commit=True,
            )
            chat_history.append(user_message)

        available_files = _collect_available_file_ids(
            chat_history=chat_history,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )

        # If the branch has a summary, collect metadata for files that appear
        # only in the summarized (now-dropped) prefix of the history.
        summary_message = find_summary_for_branch(db_session, chat_history)
        summarized_file_metadata: dict[str, FileToolMetadata] = {}
        if summary_message and summary_message.last_summarized_message_id:
            cutoff_id = summary_message.last_summarized_message_id
            for msg in chat_history:
                if msg.id > cutoff_id or not msg.files:
                    continue
                for fd in msg.files:
                    file_id = fd.get("id")
                    if not file_id:
                        continue
                    summarized_file_metadata[file_id] = FileToolMetadata(
                        file_id=file_id,
                        filename=fd.get("name") or "unknown",
                        approx_char_count=0,
                    )
            # Drop the summarized prefix; the summary is re-inserted below.
            chat_history = [m for m in chat_history if m.id > cutoff_id]

        user_memory_context = get_memories(user, db_session)
        custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)

        prompt_memory_context = (
            user_memory_context
            if user.use_memories
            else user_memory_context.without_memories()
        )

        max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
            custom_agent_prompt or ""
        )

        reserved_token_count = calculate_reserved_tokens(
            db_session=db_session,
            persona_system_prompt=max_reserved_system_prompt_tokens_str,
            token_counter=token_counter,
            files=new_msg_req.file_descriptors,
            user_memory_context=prompt_memory_context,
        )

        context_user_files = resolve_context_user_files(
            persona=persona,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )

        # Use the smallest context window across all models for safety
        min_context_window = min(llm.config.max_input_tokens for llm in llms)

        extracted_context_files = extract_context_files(
            user_files=context_user_files,
            llm_max_context_window=min_context_window,
            reserved_token_count=reserved_token_count,
            db_session=db_session,
        )

        search_params = determine_search_params(
            persona_id=persona.id,
            project_id=chat_session.project_id,
            extracted_context_files=extracted_context_files,
        )

        # Merge persona-attached user files into the available set (dedup by id).
        if persona.user_files:
            existing = set(available_files.user_file_ids)
            for uf in persona.user_files:
                if uf.id not in existing:
                    available_files.user_file_ids.append(uf.id)

        all_tools = get_tools(db_session)
        tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}

        search_tool_id = next(
            (tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
            None,
        )

        # Drop a forced search-tool selection when search is disabled.
        forced_tool_id = new_msg_req.forced_tool_id
        if (
            search_params.search_usage == SearchToolUsage.DISABLED
            and forced_tool_id is not None
            and search_tool_id is not None
            and forced_tool_id == search_tool_id
        ):
            forced_tool_id = None

        files = load_all_chat_files(chat_history, db_session)
        chat_files_for_tools = _convert_loaded_files_to_chat_files(files)

        # ── Reserve N assistant message IDs ────────────────────────────
        reserved_messages = reserve_multi_model_message_ids(
            db_session=db_session,
            chat_session_id=chat_session.id,
            parent_message_id=user_message.id,
            model_display_names=model_display_names,
        )

        yield MultiModelMessageResponseIDInfo(
            user_message_id=user_message.id,
            reserved_assistant_message_ids=[m.id for m in reserved_messages],
            model_names=model_display_names,
        )

        has_file_reader_tool = any(
            tool.in_code_tool_id == "file_reader" for tool in all_tools
        )

        chat_history_result = convert_chat_history(
            chat_history=chat_history,
            files=files,
            context_image_files=extracted_context_files.image_files,
            additional_context=new_msg_req.additional_context,
            token_counter=token_counter,
            tool_id_to_name_map=tool_id_to_name_map,
        )
        simple_chat_history = chat_history_result.simple_messages

        # Injected-file metadata is only useful when the file_reader tool exists.
        all_injected_file_metadata: dict[str, FileToolMetadata] = (
            chat_history_result.all_injected_file_metadata
            if has_file_reader_tool
            else {}
        )
        if summarized_file_metadata:
            for fid, meta in summarized_file_metadata.items():
                all_injected_file_metadata.setdefault(fid, meta)

        # Re-insert the branch summary at the head of the simplified history.
        if summary_message is not None:
            summary_simple = ChatMessageSimple(
                message=summary_message.message,
                token_count=summary_message.token_count,
                message_type=MessageType.ASSISTANT,
            )
            simple_chat_history.insert(0, summary_simple)

        # ── Stop signal and processing status ──────────────────────────
        cache = get_cache_backend()
        reset_cancel_status(chat_session.id, cache)

        def check_is_connected() -> bool:
            # True while the client has not requested a stop for this session.
            return check_stop_signal(chat_session.id, cache)

        set_processing_status(
            chat_session_id=chat_session.id,
            cache=cache,
            value=True,
        )

        # Release the main session's read transaction before the long stream
        db_session.commit()

        # ── Parallel model execution ───────────────────────────────────
        # Each model thread writes tagged packets to this shared queue.
        # Sentinel _MODEL_DONE signals that a thread finished.
        merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
            queue.Queue()
        )

        # Track per-model state containers for completion callbacks
        state_containers: list[ChatStateContainer] = [
            ChatStateContainer() for _ in range(n_models)
        ]
        # Track which models completed successfully (for completion callbacks)
        model_succeeded: list[bool] = [False] * n_models

        user_identity = LLMUserIdentity(
            user_id=llm_user_identifier,
            session_id=str(chat_session.id),
        )

        def _run_model(model_idx: int) -> None:
            """Run a single model in a worker thread.

            Uses _ModelIndexEmitter so packets flow directly to merged_queue
            in real-time (not batched after completion). This enables true
            parallel streaming where both models' tokens interleave on the wire.

            DB access: tools may need a session during execution (e.g., search
            tool). Each thread creates its own session via context manager.
            """
            model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
            sc = state_containers[model_idx]
            model_llm = llms[model_idx]

            try:
                # Each model thread gets its own DB session for tool execution.
                # The session is scoped to the thread and closed when done.
                with get_session_with_current_tenant() as thread_db_session:
                    # Construct tools per-thread with thread-local DB session
                    thread_tool_dict = construct_tools(
                        persona=persona,
                        db_session=thread_db_session,
                        emitter=model_emitter,
                        user=user,
                        llm=model_llm,
                        search_tool_config=SearchToolConfig(
                            user_selected_filters=new_msg_req.internal_search_filters,
                            project_id_filter=search_params.project_id_filter,
                            persona_id_filter=search_params.persona_id_filter,
                            bypass_acl=False,
                            enable_slack_search=_should_enable_slack_search(
                                persona, new_msg_req.internal_search_filters
                            ),
                        ),
                        custom_tool_config=CustomToolConfig(
                            chat_session_id=chat_session.id,
                            message_id=user_message.id,
                            additional_headers=custom_tool_additional_headers,
                            mcp_headers=mcp_headers,
                        ),
                        file_reader_tool_config=FileReaderToolConfig(
                            user_file_ids=available_files.user_file_ids,
                            chat_file_ids=available_files.chat_file_ids,
                        ),
                        allowed_tool_ids=new_msg_req.allowed_tool_ids,
                        search_usage_forcing_setting=search_params.search_usage,
                    )
                    model_tools: list[Tool] = []
                    for tool_list in thread_tool_dict.values():
                        model_tools.extend(tool_list)

                    # Run the LLM loop — this blocks until the model finishes.
                    # Packets flow to merged_queue in real-time via the emitter.
                    run_llm_loop(
                        emitter=model_emitter,
                        state_container=sc,
                        simple_chat_history=simple_chat_history,
                        tools=model_tools,
                        custom_agent_prompt=custom_agent_prompt,
                        context_files=extracted_context_files,
                        persona=persona,
                        user_memory_context=user_memory_context,
                        llm=model_llm,
                        token_counter=get_llm_token_counter(model_llm),
                        db_session=thread_db_session,
                        forced_tool_id=forced_tool_id,
                        user_identity=user_identity,
                        chat_session_id=str(chat_session.id),
                        chat_files=chat_files_for_tools,
                        include_citations=new_msg_req.include_citations,
                        all_injected_file_metadata=all_injected_file_metadata,
                        inject_memories_in_prompt=user.use_memories,
                    )

                model_succeeded[model_idx] = True

            except Exception as e:
                # Forward the exception to the merge loop for reporting.
                merged_queue.put((model_idx, e))

            finally:
                # Always signal completion, even after an exception.
                merged_queue.put((model_idx, _MODEL_DONE))

        # Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
        executor = ThreadPoolExecutor(
            max_workers=n_models,
            thread_name_prefix="multi-model",
        )
        futures = []
        try:
            for i in range(n_models):
                futures.append(executor.submit(_run_model, i))

            # ── Main thread: merge and yield packets ───────────────────
            models_remaining = n_models
            while models_remaining > 0:
                try:
                    model_idx, item = merged_queue.get(timeout=0.3)
                except queue.Empty:
                    # Check cancellation during idle periods
                    if not check_is_connected():
                        # NOTE(review): this return still waits for worker threads
                        # to finish (executor.shutdown(wait=True) in the finally
                        # below) — cancellation is not propagated to run_llm_loop.
                        yield Packet(
                            placement=Placement(turn_index=0),
                            obj=OverallStop(type="stop", stop_reason="user_cancelled"),
                        )
                        return
                    continue

                if item is _MODEL_DONE:
                    models_remaining -= 1
                    continue

                if isinstance(item, Exception):
                    # Yield error as a tagged StreamingError packet
                    error_msg = str(item)
                    stack_trace = "".join(
                        traceback.format_exception(type(item), item, item.__traceback__)
                    )
                    # Redact API keys from error messages
                    model_llm = llms[model_idx]
                    if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
                        error_msg = error_msg.replace(
                            model_llm.config.api_key, "[REDACTED_API_KEY]"
                        )
                        stack_trace = stack_trace.replace(
                            model_llm.config.api_key, "[REDACTED_API_KEY]"
                        )

                    yield StreamingError(
                        error=error_msg,
                        stack_trace=stack_trace,
                        error_code="MODEL_ERROR",
                        is_retryable=True,
                        details={
                            "model": model_llm.config.model_name,
                            "provider": model_llm.config.model_provider,
                            "model_index": model_idx,
                        },
                    )
                    # NOTE(review): a failing worker enqueues BOTH its exception
                    # and the _MODEL_DONE sentinel (see _run_model's finally), so
                    # this decrement plus the sentinel's decrement counts one
                    # failed model twice. With one failure the loop can exit
                    # early while other models are still streaming, dropping
                    # their packets — confirm and remove this decrement (or skip
                    # the sentinel after an error).
                    models_remaining -= 1
                    continue

                if isinstance(item, Packet):
                    # Packet is already tagged with model_index by _ModelIndexEmitter
                    yield item

            # ── Completion: save each successful model's response ──────
            # Run completion callbacks on the main thread using the main
            # session. This is safe because all worker threads have exited
            # by this point (merged_queue fully drained).
            for i in range(n_models):
                if not model_succeeded[i]:
                    continue
                try:
                    llm_loop_completion_handle(
                        state_container=state_containers[i],
                        is_connected=check_is_connected,
                        db_session=db_session,
                        assistant_message=reserved_messages[i],
                        llm=llms[i],
                        reserved_tokens=reserved_token_count,
                    )
                except Exception:
                    logger.exception(
                        f"Failed completion for model {i} "
                        f"({model_display_names[i]})"
                    )

            yield Packet(
                placement=Placement(turn_index=0),
                obj=OverallStop(type="stop", stop_reason="complete"),
            )

        finally:
            # Ensure all threads are cleaned up regardless of how we exit
            executor.shutdown(wait=True, cancel_futures=True)

    except ValueError as e:
        logger.exception("Failed to process multi-model chat message.")
        yield StreamingError(
            error=str(e),
            error_code="VALIDATION_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
        return

    except Exception as e:
        logger.exception(f"Failed multi-model chat: {e}")
        stack_trace = traceback.format_exc()
        yield StreamingError(
            error=str(e),
            stack_trace=stack_trace,
            error_code="MULTI_MODEL_ERROR",
            is_retryable=True,
        )
        db_session.rollback()

    finally:
        # Best-effort: clear the processing flag if it was ever set.
        try:
            if cache is not None and chat_session is not None:
                set_processing_status(
                    chat_session_id=chat_session.id,
                    cache=cache,
                    value=False,
                )
        except Exception:
            logger.exception("Error clearing processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
|
||||
@@ -123,7 +123,7 @@ class OnyxConfluence:
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": False,
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
@@ -456,7 +456,7 @@ class OnyxConfluence:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
|
||||
@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
@@ -408,17 +408,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) ->
|
||||
|
||||
raise e
|
||||
|
||||
if e.response.status_code >= 500:
|
||||
if attempt >= max_retries - 1:
|
||||
raise e
|
||||
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
logger.warning(
|
||||
f"Server error {e.response.status_code}. "
|
||||
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
|
||||
)
|
||||
return math.ceil(time.monotonic() + delay)
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
@@ -401,16 +401,3 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class PersonaSearchInfo(BaseModel):
    """Snapshot of persona data needed by the search pipeline.

    Extracted from the ORM Persona before the DB session is released so that
    SearchTool and search_pipeline never lazy-load relationships post-commit.
    """

    # Names of the document sets attached to the persona (used as the
    # persona-level document-set filter in search_pipeline).
    document_set_names: list[str]
    # Persona time cutoff for retrieval; None means no date restriction.
    search_start_date: datetime | None
    # IDs of documents explicitly attached to the persona
    # ("assistant knowledge" filter).
    attached_document_ids: list[str]
    # IDs of hierarchy nodes the persona's search is scoped to.
    hierarchy_node_ids: list[int]
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.retrieval.search_runner import search_chunks
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
@@ -247,8 +247,8 @@ def search_pipeline(
|
||||
document_index: DocumentIndex,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Pre-extracted persona search configuration (None when no persona)
|
||||
persona_search_info: PersonaSearchInfo | None,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -263,18 +263,24 @@ def search_pipeline(
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
persona_document_sets: list[str] | None = (
|
||||
persona_search_info.document_set_names if persona_search_info else None
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
else None
|
||||
)
|
||||
persona_time_cutoff: datetime | None = (
|
||||
persona_search_info.search_start_date if persona_search_info else None
|
||||
persona.search_start_date if persona else None
|
||||
)
|
||||
|
||||
# Extract assistant knowledge filters from persona
|
||||
attached_document_ids: list[str] | None = (
|
||||
persona_search_info.attached_document_ids or None
|
||||
if persona_search_info
|
||||
[doc.id for doc in persona.attached_documents]
|
||||
if persona and persona.attached_documents
|
||||
else None
|
||||
)
|
||||
hierarchy_node_ids: list[int] | None = (
|
||||
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
|
||||
[node.id for node in persona.hierarchy_nodes]
|
||||
if persona and persona.hierarchy_nodes
|
||||
else None
|
||||
)
|
||||
|
||||
filters = _build_index_filters(
|
||||
|
||||
@@ -16,7 +16,6 @@ from sqlalchemy import Row
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -29,7 +28,6 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -55,22 +53,9 @@ def get_chat_session_by_id(
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_shared: bool = False,
|
||||
eager_load_persona: bool = False,
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
|
||||
if eager_load_persona:
|
||||
stmt = stmt.options(
|
||||
joinedload(ChatSession.persona).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
),
|
||||
joinedload(ChatSession.project),
|
||||
)
|
||||
|
||||
if is_shared:
|
||||
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
|
||||
else:
|
||||
@@ -617,6 +602,79 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model user message.
|
||||
|
||||
Validates that the user message is a USER type and that the preferred
|
||||
assistant message is a direct child of that user message.
|
||||
"""
|
||||
user_msg = db_session.query(ChatMessage).get(user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.query(ChatMessage).get(preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +897,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -750,31 +750,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -1,31 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""
|
||||
What kind of account this is — determines whether the user
|
||||
enters the group-based permission system.
|
||||
|
||||
STANDARD + SERVICE_ACCOUNT → participate in group system
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -341,54 +314,3 @@ class HookPoint(str, PyEnum):
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
|
||||
class Permission(str, PyEnum):
|
||||
"""
|
||||
Permission tokens for group-based authorization.
|
||||
19 tokens total. full_admin_panel_access is an override —
|
||||
if present, any permission check passes.
|
||||
"""
|
||||
|
||||
# Basic (auto-granted to every new group)
|
||||
BASIC_ACCESS = "basic"
|
||||
|
||||
# Read tokens — implied only, never granted directly
|
||||
READ_CONNECTORS = "read:connectors"
|
||||
READ_DOCUMENT_SETS = "read:document_sets"
|
||||
READ_AGENTS = "read:agents"
|
||||
READ_USERS = "read:users"
|
||||
|
||||
# Add / Manage pairs
|
||||
ADD_AGENTS = "add:agents"
|
||||
MANAGE_AGENTS = "manage:agents"
|
||||
MANAGE_DOCUMENT_SETS = "manage:document_sets"
|
||||
ADD_CONNECTORS = "add:connectors"
|
||||
MANAGE_CONNECTORS = "manage:connectors"
|
||||
MANAGE_LLMS = "manage:llms"
|
||||
|
||||
# Toggle tokens
|
||||
READ_AGENT_ANALYTICS = "read:agent_analytics"
|
||||
MANAGE_ACTIONS = "manage:actions"
|
||||
READ_QUERY_HISTORY = "read:query_history"
|
||||
MANAGE_USER_GROUPS = "manage:user_groups"
|
||||
CREATE_USER_API_KEYS = "create:user_api_keys"
|
||||
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
|
||||
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
|
||||
|
||||
# Override — any permission check passes
|
||||
FULL_ADMIN_PANEL_ACCESS = "admin"
|
||||
|
||||
# Permissions that are implied by other grants and must never be stored
|
||||
# directly in the permission_grant table.
|
||||
IMPLIED: ClassVar[frozenset[Permission]]
|
||||
|
||||
|
||||
Permission.IMPLIED = frozenset(
|
||||
{
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -30,9 +28,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
@@ -977,106 +972,3 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
|
||||
@@ -48,7 +48,6 @@ from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import (
|
||||
ANONYMOUS_USER_UUID,
|
||||
@@ -79,8 +78,6 @@ from onyx.db.enums import (
|
||||
MCPAuthenticationPerformer,
|
||||
MCPTransport,
|
||||
MCPServerStatus,
|
||||
Permission,
|
||||
GrantSource,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
@@ -305,9 +302,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
Preferences probably should be in a separate table at some point, but for now
|
||||
@@ -2651,6 +2645,15 @@ class ChatMessage(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# For multi-model turns: the user message points to which assistant response
|
||||
# was selected as the preferred one to continue the conversation with.
|
||||
preferred_response_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id"), nullable=True
|
||||
)
|
||||
|
||||
# The display name of the model that generated this assistant message
|
||||
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# What does this message contain
|
||||
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -2718,6 +2721,12 @@ class ChatMessage(Base):
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
preferred_response: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[preferred_response_id],
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
# Chat messages only need to know their immediate tool call children
|
||||
# If there are nested tool calls, they are stored in the tool_call_children relationship.
|
||||
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
|
||||
@@ -3977,8 +3986,6 @@ class SamlAccount(Base):
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
|
||||
|
||||
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
@@ -3989,48 +3996,6 @@ class User__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class PermissionGrant(Base):
|
||||
__tablename__ = "permission_grant"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
)
|
||||
granted_by: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
granted_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
group: Mapped["UserGroup"] = relationship(
|
||||
"UserGroup", back_populates="permission_grants"
|
||||
)
|
||||
|
||||
@validates("permission")
|
||||
def _validate_permission(self, _key: str, value: Permission) -> Permission:
|
||||
if value in Permission.IMPLIED:
|
||||
raise ValueError(
|
||||
f"{value!r} is an implied permission and cannot be granted directly"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
@@ -4125,8 +4090,6 @@ class UserGroup(Base):
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -4170,9 +4133,6 @@ class UserGroup(Base):
|
||||
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
|
||||
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
|
||||
)
|
||||
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
|
||||
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Token Rate Limiting
|
||||
|
||||
@@ -50,18 +50,8 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_default_behavior_persona(
|
||||
db_session: Session,
|
||||
eager_load_for_tools: bool = False,
|
||||
) -> Persona | None:
|
||||
def get_default_behavior_persona(db_session: Session) -> Persona | None:
|
||||
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
|
||||
if eager_load_for_tools:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
)
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ class OnyxErrorCode(Enum):
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
QUERY_REJECTED = ("QUERY_REJECTED", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
|
||||
@@ -5,7 +5,6 @@ Usage (Celery tasks and FastAPI handlers):
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
@@ -15,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (spec.response_model)
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
@@ -54,11 +53,9 @@ The executor uses three sessions:
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -84,9 +81,6 @@ class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -274,21 +268,22 @@ def _persist_result(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
@@ -305,36 +300,13 @@ def _execute_hook_inner(
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
@@ -351,41 +323,8 @@ def _execute_hook_inner(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
return outcome.response_payload
|
||||
|
||||
@@ -51,12 +51,13 @@ class HookPointSpec:
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every subclass declares all required class attributes.
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Raises TypeError at import time if any required attribute is missing or if
|
||||
payload_model / response_model are not Pydantic BaseModel subclasses.
|
||||
input_schema and output_schema are derived automatically from the models.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
|
||||
@@ -26,8 +26,6 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class QueryProcessingResponse(BaseModel):
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, whitespace-only, or absent = reject the query."
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
@@ -65,8 +65,6 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
display_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.celery_task_metrics import (
|
||||
on_celery_task_prerun,
|
||||
on_celery_task_postrun,
|
||||
on_celery_task_retry,
|
||||
on_celery_task_revoked,
|
||||
on_celery_task_rejected,
|
||||
)
|
||||
# Call from the worker's existing signal handlers
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
TASK_STARTED = Counter(
|
||||
"onyx_celery_task_started_total",
|
||||
"Total Celery tasks started",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_COMPLETED = Counter(
|
||||
"onyx_celery_task_completed_total",
|
||||
"Total Celery tasks completed",
|
||||
["task_name", "queue", "outcome"],
|
||||
)
|
||||
|
||||
TASK_DURATION = Histogram(
|
||||
"onyx_celery_task_duration_seconds",
|
||||
"Celery task execution duration in seconds",
|
||||
["task_name", "queue"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
TASKS_ACTIVE = Gauge(
|
||||
"onyx_celery_tasks_active",
|
||||
"Currently executing Celery tasks",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_RETRIED = Counter(
|
||||
"onyx_celery_task_retried_total",
|
||||
"Total Celery tasks retried",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_REVOKED = Counter(
|
||||
"onyx_celery_task_revoked_total",
|
||||
"Total Celery tasks revoked (cancelled)",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_REJECTED = Counter(
|
||||
"onyx_celery_task_rejected_total",
|
||||
"Total Celery tasks rejected by worker",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
# Lock protecting _task_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_task_start_times_lock = threading.Lock()
|
||||
|
||||
# Entries older than this are evicted on each prerun to prevent unbounded
|
||||
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
|
||||
_MAX_START_TIME_AGE_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _task_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, (start, _labels) in _task_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
entry = _task_start_times.pop(tid, None)
|
||||
if entry is not None:
|
||||
_labels = entry[1]
|
||||
# Decrement active gauge for evicted tasks — these tasks were
|
||||
# started but never completed (killed, OOM, etc.).
|
||||
active_gauge = TASKS_ACTIVE.labels(**_labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
|
||||
def _get_task_labels(task: Task) -> dict[str, str]:
|
||||
"""Extract task_name and queue labels from a Celery Task instance."""
|
||||
task_name = task.name or "unknown"
|
||||
queue = "unknown"
|
||||
try:
|
||||
delivery_info = task.request.delivery_info
|
||||
if delivery_info:
|
||||
queue = delivery_info.get("routing_key") or "unknown"
|
||||
except AttributeError:
|
||||
pass
|
||||
return {"task_name": task_name, "queue": queue}
|
||||
|
||||
|
||||
def on_celery_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task start. Call from the worker's task_prerun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_STARTED.labels(**labels).inc()
|
||||
TASKS_ACTIVE.labels(**labels).inc()
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_postrun(
    task_id: str | None,
    task: Task | None,
    state: str | None,
) -> None:
    """Record task completion. Call from the worker's task_postrun signal handler."""
    if task is None or task_id is None:
        return

    try:
        labels = _get_task_labels(task)
        # Any state other than SUCCESS (FAILURE, RETRY, None, ...) counts as
        # a failure for the completion counter.
        outcome = "success" if state == "SUCCESS" else "failure"
        TASK_COMPLETED.labels(**labels, outcome=outcome).inc()

        # Guard against going below 0 if postrun fires without a matching
        # prerun (e.g. after a worker restart or stale entry eviction).
        active_gauge = TASKS_ACTIVE.labels(**labels)
        if active_gauge._value.get() > 0:
            active_gauge.dec()

        with _task_start_times_lock:
            entry = _task_start_times.pop(task_id, None)
            if entry is not None:
                start_time, _stored_labels = entry
                # Duration is only observable when the matching prerun was
                # seen (otherwise there is no start time to measure from).
                TASK_DURATION.labels(**labels).observe(time.monotonic() - start_time)
    except Exception:
        logger.debug("Failed to record celery task postrun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_retry(
    _task_id: str | None,
    task: Task | None,
) -> None:
    """Bump the retry counter; wire this to Celery's task_retry signal."""
    if task is None:
        return
    try:
        retry_labels = _get_task_labels(task)
        TASK_RETRIED.labels(**retry_labels).inc()
    except Exception:
        logger.debug("Failed to record celery task retry metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_revoked(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Bump the revocation counter; wire this to Celery's task_revoked signal.

    The revoked signal carries no Task object — only the task name (passed
    via the signal's sender) is available for labeling.
    """
    if task_name is not None:
        try:
            TASK_REVOKED.labels(task_name=task_name).inc()
        except Exception:
            logger.debug("Failed to record celery task revoked metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_rejected(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Bump the rejection counter; wire this to Celery's task_rejected signal."""
    if task_name is not None:
        try:
            TASK_REJECTED.labels(task_name=task_name).inc()
        except Exception:
            logger.debug("Failed to record celery task rejected metrics", exc_info=True)
|
||||
@@ -1,528 +0,0 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default cache TTL in seconds. Scrapes hitting within this window return
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0

# Maps internal Celery queue names to the short, stable values exported in
# the `queue` Prometheus label.
_QUEUE_LABEL_MAP: dict[str, str] = {
    OnyxCeleryQueues.PRIMARY: "primary",
    OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
    OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
    OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
    OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
    OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
    OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
    OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
    OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
    OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
    OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
    OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
    OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
    OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
    OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
    OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
    OnyxCeleryQueues.MONITORING: "monitoring",
    OnyxCeleryQueues.SANDBOX: "sandbox",
    OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
}

# Queues where prefetched (unacked) task counts are meaningful
_UNACKED_QUEUES: list[str] = [
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
    OnyxCeleryQueues.DOCPROCESSING,
]
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
    """Collector base class that memoizes scrape results for a short TTL.

    Subclasses override ``_collect_fresh()`` to hit the real data source;
    ``collect()`` serves the memoized result while the TTL is still live,
    so frequent Prometheus scrapes do not translate into repeated
    Redis/Postgres queries.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        self._cache_ttl = cache_ttl
        self._cached_result: list[GaugeMetricFamily] | None = None
        self._last_collect_time: float = 0.0
        self._lock = threading.Lock()

    def collect(self) -> list[GaugeMetricFamily]:
        with self._lock:
            now = time.monotonic()
            cache_live = (now - self._last_collect_time) < self._cache_ttl
            if cache_live and self._cached_result is not None:
                return self._cached_result

            try:
                fresh = self._collect_fresh()
            except Exception:
                logger.exception(f"Error in {type(self).__name__}.collect()")
                # Serve the stale cache instead of nothing so metrics don't
                # vanish from dashboards during a transient failure.
                return self._cached_result if self._cached_result is not None else []

            self._cached_result = fresh
            self._last_collect_time = now
            return fresh

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        raise NotImplementedError

    def describe(self) -> list[GaugeMetricFamily]:
        # An empty describe() lets the registry register the collector
        # without invoking collect() up front.
        return []
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
    """Reads Celery queue lengths from the broker Redis on each scrape.

    Uses a Redis client factory (callable) rather than a stored client
    reference so the connection is always fresh from Celery's pool.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Set later via set_redis_factory(); collection is a no-op until then.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        # Not configured yet (setup hasn't run) — emit nothing.
        if self._get_redis is None:
            return []

        redis_client = self._get_redis()

        depth = GaugeMetricFamily(
            "onyx_queue_depth",
            "Number of tasks waiting in Celery queue",
            labels=["queue"],
        )
        unacked = GaugeMetricFamily(
            "onyx_queue_unacked",
            "Number of prefetched (unacked) tasks for queue",
            labels=["queue"],
        )
        queue_age = GaugeMetricFamily(
            "onyx_queue_oldest_task_age_seconds",
            "Age of the oldest task in the queue (seconds since enqueue)",
            labels=["queue"],
        )

        # Wall-clock time — queue message timestamps are wall-clock, so
        # monotonic time would not be comparable here.
        now = time.time()

        for queue_name, label in _QUEUE_LABEL_MAP.items():
            length = celery_get_queue_length(queue_name, redis_client)
            depth.add_metric([label], length)

            # Peek at the oldest message to get its age
            if length > 0:
                age = self._get_oldest_message_age(redis_client, queue_name, now)
                if age is not None:
                    queue_age.add_metric([label], age)

        for queue_name in _UNACKED_QUEUES:
            label = _QUEUE_LABEL_MAP[queue_name]
            task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
            unacked.add_metric([label], len(task_ids))

        return [depth, unacked, queue_age]

    @staticmethod
    def _get_oldest_message_age(
        redis_client: Redis, queue_name: str, now: float
    ) -> float | None:
        """Peek at the oldest (tail) message in a Redis list queue
        and extract its timestamp to compute age.

        Note: If the Celery message contains neither ``properties.timestamp``
        nor ``headers.timestamp``, no age metric is emitted for this queue.
        This can happen with custom task producers or non-standard Celery
        protocol versions. The metric will simply be absent rather than
        inaccurate, which is the safest behavior for alerting.
        """
        try:
            raw: bytes | str | None = redis_client.lindex(queue_name, -1)  # type: ignore[assignment]
            if raw is None:
                return None
            msg = json.loads(raw)
            # Check for ETA tasks first — they are intentionally delayed,
            # so reporting their queue age would be misleading.
            headers = msg.get("headers", {})
            if headers.get("eta") is not None:
                return None
            # Celery v2 protocol: timestamp in properties
            props = msg.get("properties", {})
            ts = props.get("timestamp")
            if ts is not None:
                return now - float(ts)
            # Fallback: some Celery configurations place the timestamp in
            # headers instead of properties.
            ts = headers.get("timestamp")
            if ts is not None:
                return now - float(ts)
        except Exception:
            # Malformed message or Redis hiccup — silently omit the age metric.
            pass
        return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
    """Queries Postgres for index attempt state on each scrape."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Flipped to True by configure() once the DB engine is ready.
        self._configured: bool = False
        # Populated by configure(); not referenced by _collect_fresh() in
        # this class — kept for external inspection/debugging.
        self._terminal_statuses: list = []

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        # Imported lazily so this module can load before the DB layer.
        from onyx.db.enums import IndexingStatus

        self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        if not self._configured:
            return []

        # Lazy imports: keep DB dependencies out of module import time.
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_active_index_attempts_for_metrics
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        attempts_gauge = GaugeMetricFamily(
            "onyx_index_attempts_active",
            "Number of non-terminal index attempts",
            labels=[
                "status",
                "source",
                "tenant_id",
                "connector_name",
                "cc_pair_id",
            ],
        )

        tenant_ids = get_all_tenant_ids()

        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    rows = get_active_index_attempts_for_metrics(session)

                    for status, source, cc_id, cc_name, count in rows:
                        # Fall back to a synthetic name for unnamed cc pairs.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        attempts_gauge.add_metric(
                            [
                                status.value,
                                source.value,
                                tid,
                                name_val,
                                str(cc_id),
                            ],
                            count,
                        )
            finally:
                # Always restore the previous tenant context.
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
    """Queries Postgres for connector health state on each scrape."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Flipped to True by configure() once the DB engine is ready.
        self._configured: bool = False

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        if not self._configured:
            return []

        # Lazy imports: keep DB dependencies out of module import time.
        from onyx.db.connector_credential_pair import (
            get_connector_health_for_metrics,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
        from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        staleness_gauge = GaugeMetricFamily(
            "onyx_connector_last_success_age_seconds",
            "Seconds since last successful index for this connector",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        error_state_gauge = GaugeMetricFamily(
            "onyx_connector_in_error_state",
            "Whether the connector is in a repeated error state (1=yes, 0=no)",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        by_status_gauge = GaugeMetricFamily(
            "onyx_connectors_by_status",
            "Number of connectors grouped by status",
            labels=["tenant_id", "status"],
        )
        error_total_gauge = GaugeMetricFamily(
            "onyx_connectors_in_error_total",
            "Total number of connectors in repeated error state",
            labels=["tenant_id"],
        )
        per_connector_labels = [
            "tenant_id",
            "source",
            "cc_pair_id",
            "connector_name",
        ]
        docs_success_gauge = GaugeMetricFamily(
            "onyx_connector_docs_indexed",
            "Total new documents indexed (90-day rolling sum) per connector",
            labels=per_connector_labels,
        )
        docs_error_gauge = GaugeMetricFamily(
            "onyx_connector_error_count",
            "Total number of failed index attempts per connector",
            labels=per_connector_labels,
        )

        now = datetime.now(tz=timezone.utc)
        tenant_ids = get_all_tenant_ids()

        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    pairs = get_connector_health_for_metrics(session)
                    error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
                    docs_by_cc = get_docs_indexed_by_cc_pair(session)

                    # Per-tenant aggregates built while iterating the pairs.
                    status_counts: dict[str, int] = {}
                    error_count = 0

                    for (
                        cc_id,
                        status,
                        in_error,
                        last_success,
                        cc_name,
                        source,
                    ) in pairs:
                        cc_id_str = str(cc_id)
                        source_val = source.value
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        label_vals = [tid, source_val, cc_id_str, name_val]

                        if last_success is not None:
                            # Both `now` and `last_success` are timezone-aware
                            # (the DB column uses DateTime(timezone=True)),
                            # so subtraction is safe.
                            age = (now - last_success).total_seconds()
                            staleness_gauge.add_metric(label_vals, age)

                        error_state_gauge.add_metric(
                            label_vals,
                            1.0 if in_error else 0.0,
                        )
                        if in_error:
                            error_count += 1

                        docs_success_gauge.add_metric(
                            label_vals,
                            docs_by_cc.get(cc_id, 0),
                        )

                        docs_error_gauge.add_metric(
                            label_vals,
                            error_counts_by_cc.get(cc_id, 0),
                        )

                        status_val = status.value
                        status_counts[status_val] = status_counts.get(status_val, 0) + 1

                    for status_val, count in status_counts.items():
                        by_status_gauge.add_metric([tid, status_val], count)

                    error_total_gauge.add_metric([tid], error_count)
            finally:
                # Always restore the previous tenant context.
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        return [
            staleness_gauge,
            error_state_gauge,
            by_status_gauge,
            error_total_gauge,
            docs_success_gauge,
            docs_error_gauge,
        ]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
    """Collects Redis server health metrics (memory, clients, etc.)."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Set later via set_redis_factory(); collection is a no-op until then.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        if self._get_redis is None:
            return []

        redis_client = self._get_redis()

        memory_used = GaugeMetricFamily(
            "onyx_redis_memory_used_bytes",
            "Redis used memory in bytes",
        )
        memory_peak = GaugeMetricFamily(
            "onyx_redis_memory_peak_bytes",
            "Redis peak used memory in bytes",
        )
        memory_frag = GaugeMetricFamily(
            "onyx_redis_memory_fragmentation_ratio",
            "Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
        )
        connected_clients = GaugeMetricFamily(
            "onyx_redis_connected_clients",
            "Number of connected Redis clients",
        )

        try:
            mem_info: dict = redis_client.info("memory")  # type: ignore[assignment]
            memory_used.add_metric([], mem_info.get("used_memory", 0))
            memory_peak.add_metric([], mem_info.get("used_memory_peak", 0))
            # Fragmentation ratio may be absent from INFO output; only emit
            # the metric when present.
            frag = mem_info.get("mem_fragmentation_ratio")
            if frag is not None:
                memory_frag.add_metric([], frag)

            client_info: dict = redis_client.info("clients")  # type: ignore[assignment]
            connected_clients.add_metric([], client_info.get("connected_clients", 0))
        except Exception:
            # On failure, still return the (possibly empty) families rather
            # than raising — collect() would otherwise fall back to stale data.
            logger.debug("Failed to collect Redis health metrics", exc_info=True)

        return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
    """Collects Celery worker count and process count via inspect ping.

    Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
    command that takes a couple seconds to complete.

    Maintains a set of known worker short-names so that when a worker
    stops responding, we emit ``up=0`` instead of silently dropping the
    metric (which would make ``absent()``-style alerts impossible).
    """

    # Remove a worker from _known_workers after this many consecutive
    # missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
    _MAX_CONSECUTIVE_MISSES = 10

    def __init__(self, cache_ttl: float = 60.0) -> None:
        super().__init__(cache_ttl)
        # Set later via set_celery_app(); collection is a no-op until then.
        self._celery_app: Any | None = None
        # worker short-name → consecutive miss count.
        # Workers start at 0 and reset to 0 each time they respond.
        # Removed after _MAX_CONSECUTIVE_MISSES missed collects.
        self._known_workers: dict[str, int] = {}

    def set_celery_app(self, app: Any) -> None:
        """Set the Celery app instance for inspect commands."""
        self._celery_app = app

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        if self._celery_app is None:
            return []

        active_workers = GaugeMetricFamily(
            "onyx_celery_active_worker_count",
            "Number of active Celery workers responding to ping",
        )
        worker_up = GaugeMetricFamily(
            "onyx_celery_worker_up",
            "Whether a specific Celery worker is alive (1=up, 0=down)",
            labels=["worker"],
        )

        try:
            inspector = self._celery_app.control.inspect(timeout=3.0)
            ping_result = inspector.ping()

            responding: set[str] = set()
            if ping_result:
                active_workers.add_metric([], len(ping_result))
                for worker_name in ping_result:
                    # Strip hostname suffix for cleaner labels
                    short_name = worker_name.split("@")[0]
                    responding.add(short_name)
            else:
                # No reply at all — explicitly report zero workers.
                active_workers.add_metric([], 0)

            # Register newly-seen workers and reset miss count for
            # workers that responded.
            for short_name in responding:
                self._known_workers[short_name] = 0

            # Increment miss count for non-responding workers and evict
            # those that have been missing too long.
            stale = []
            for short_name in list(self._known_workers):
                if short_name not in responding:
                    self._known_workers[short_name] += 1
                    if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
                        stale.append(short_name)
            for short_name in stale:
                del self._known_workers[short_name]

            # Sorted iteration keeps metric ordering stable across scrapes.
            for short_name in sorted(self._known_workers):
                worker_up.add_metric([short_name], 1 if short_name in responding else 0)
        except Exception:
            logger.debug("Failed to collect worker health metrics", exc_info=True)

        return [active_workers, worker_up]
|
||||
@@ -1,113 +0,0 @@
|
||||
"""Setup function for indexing pipeline Prometheus collectors.
|
||||
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()

# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
    """Create a factory that returns a cached broker Redis client.

    Reuses a single connection across scrapes to avoid leaking connections.
    Reconnects automatically if the cached connection becomes stale.
    """
    # Single-element lists act as mutable cells the closures can write to.
    _cached_client: list[Redis | None] = [None]
    # Keep a reference to the Kombu Connection so we can close it on
    # reconnect (the raw Redis client outlives the Kombu wrapper).
    _cached_kombu_conn: list[Any] = [None]

    def _close_client(client: Redis) -> None:
        """Best-effort close of a Redis client."""
        try:
            client.close()
        except Exception:
            logger.debug("Failed to close stale Redis client", exc_info=True)

    def _close_kombu_conn() -> None:
        """Best-effort close of the cached Kombu Connection."""
        conn = _cached_kombu_conn[0]
        if conn is not None:
            try:
                conn.close()
            except Exception:
                logger.debug("Failed to close Kombu connection", exc_info=True)
            # Clear the cell even if close() failed — the conn is unusable.
            _cached_kombu_conn[0] = None

    def _get_broker_redis() -> Redis:
        # Reuse the cached client while it still responds to PING.
        client = _cached_client[0]
        if client is not None:
            try:
                client.ping()
                return client
            except Exception:
                logger.debug("Cached Redis client stale, reconnecting")
                _close_client(client)
                _cached_client[0] = None
                _close_kombu_conn()

        # Get a fresh Redis client from the broker connection.
        # We hold this client long-term (cached above) rather than using a
        # context manager, because we need it to persist across scrapes.
        # The caching logic above ensures we only ever hold one connection,
        # and we close it explicitly on reconnect.
        conn = celery_app.broker_connection()
        # kombu's Channel exposes .client at runtime (the underlying Redis
        # client) but the type stubs don't declare it.
        new_client: Redis = conn.channel().client  # type: ignore[attr-defined]
        _cached_client[0] = new_client
        _cached_kombu_conn[0] = conn
        return new_client

    return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
    """Register all indexing pipeline collectors with the default registry.

    Args:
        celery_app: The Celery application instance. Used to obtain a fresh
            broker Redis client on each scrape for queue depth metrics.
    """
    factory = _make_broker_redis_factory(celery_app)
    _queue_collector.set_redis_factory(factory)
    _redis_health_collector.set_redis_factory(factory)
    _worker_health_collector.set_celery_app(celery_app)
    _attempt_collector.configure()
    _connector_collector.configure()

    all_collectors = [
        _queue_collector,
        _attempt_collector,
        _connector_collector,
        _redis_health_collector,
        _worker_health_collector,
    ]
    for coll in all_collectors:
        try:
            REGISTRY.register(coll)
        except ValueError:
            # Duplicate registration (e.g. worker re-init) is harmless.
            logger.debug("Collector already registered: %s", type(coll).__name__)
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Per-connector Prometheus metrics for indexing tasks.
|
||||
|
||||
Enriches the two primary indexing tasks (docfetching_proxy_task and
|
||||
docprocessing_task) with connector-level labels: source, tenant_id,
|
||||
and cc_pair_id.
|
||||
|
||||
Note: connector_name is intentionally excluded from push-based per-task
|
||||
counters because it is a user-defined free-form string that can create
|
||||
unbounded cardinality. The pull-based collectors on the monitoring worker
|
||||
(see indexing_pipeline.py) include connector_name since they have bounded
|
||||
cardinality (one series per connector, not per task execution).
|
||||
|
||||
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
|
||||
Connectors never change source type, and names change rarely, so the
|
||||
cache is safe to hold for the worker's lifetime.
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.indexing_task_metrics import (
|
||||
on_indexing_task_prerun,
|
||||
on_indexing_task_postrun,
|
||||
)
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConnectorInfo:
    """Immutable connector metadata used as Prometheus label values."""

    # Connector source type (string form of the source enum value).
    source: str
    # Human-readable connector name.
    name: str
|
||||
|
||||
|
||||
# Sentinel returned when a cc_pair cannot be resolved to real metadata.
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")

# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
# deployments where different tenants can share the same cc_pair_id value.
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}

# Lock protecting _connector_cache — multiple thread-pool workers may
# resolve connectors concurrently.
_connector_cache_lock = threading.Lock()

# Only enrich these task types with per-connector labels
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
    {
        OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
        OnyxCeleryTask.DOCPROCESSING_TASK,
    }
)

# connector_name is intentionally excluded — see module docstring.
INDEXING_TASK_STARTED = Counter(
    "onyx_indexing_task_started_total",
    "Indexing tasks started per connector",
    ["task_name", "source", "tenant_id", "cc_pair_id"],
)

INDEXING_TASK_COMPLETED = Counter(
    "onyx_indexing_task_completed_total",
    "Indexing tasks completed per connector",
    [
        "task_name",
        "source",
        "tenant_id",
        "cc_pair_id",
        "outcome",
    ],
)

INDEXING_TASK_DURATION = Histogram(
    "onyx_indexing_task_duration_seconds",
    "Indexing task duration by connector type",
    ["task_name", "source", "tenant_id"],
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)

# task_id → monotonic start time (for indexing tasks only)
_indexing_start_times: dict[str, float] = {}

# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_indexing_start_times_lock = threading.Lock()
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
    """Purge _indexing_start_times entries past the maximum tracked age.

    Caller must hold _indexing_start_times_lock.
    """
    cutoff = time.monotonic() - _MAX_START_TIME_AGE_SECONDS
    expired = [
        task_id
        for task_id, started in _indexing_start_times.items()
        if started < cutoff
    ]
    for task_id in expired:
        _indexing_start_times.pop(task_id, None)
|
||||
|
||||
|
||||
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
    """Resolve cc_pair_id to ConnectorInfo, using cache when possible.

    On cache miss, does a single DB query with eager connector load.
    On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
    subsequent calls can retry the lookup once the DB is available.

    Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
    cache key. The Celery tenant-aware middleware sets this contextvar before
    task execution, and it always matches kwargs["tenant_id"] (which is set
    at task dispatch time). They are guaranteed to agree for a given task
    execution context.
    """
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
    cache_key = (tenant_id, cc_pair_id)

    # Fast path: cache hit under the lock.
    with _connector_cache_lock:
        cached = _connector_cache.get(cache_key)
        if cached is not None:
            return cached

    try:
        # Imported lazily so the metrics hooks don't force DB imports at
        # module load time.
        from onyx.db.connector_credential_pair import (
            get_connector_credential_pair_from_id,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session,
                cc_pair_id,
                eager_load_connector=True,
            )
            if cc_pair is None:
                # DB lookup succeeded but cc_pair doesn't exist — don't cache,
                # it may appear later (race with connector creation).
                return _UNKNOWN_CONNECTOR

            info = ConnectorInfo(
                source=cc_pair.connector.source.value,
                name=cc_pair.name,
            )
            with _connector_cache_lock:
                _connector_cache[cache_key] = info
            return info
    except Exception:
        logger.debug(
            f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
            exc_info=True,
        )
        return _UNKNOWN_CONNECTOR
|
||||
|
||||
|
||||
def on_indexing_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
) -> None:
|
||||
"""Record per-connector metrics at task start.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
|
||||
all other tasks.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
|
||||
INDEXING_TASK_STARTED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_indexing_start_times[task_id] = time.monotonic()
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_indexing_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record per-connector completion metrics.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
|
||||
INDEXING_TASK_COMPLETED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
outcome=outcome,
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
start = _indexing_start_times.pop(task_id, None)
|
||||
if start is not None:
|
||||
INDEXING_TASK_DURATION.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
).observe(time.monotonic() - start)
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task postrun metrics", exc_info=True)
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Standalone Prometheus metrics HTTP server for non-API processes.
|
||||
|
||||
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
|
||||
Celery workers and other background processes use this module to expose their
|
||||
own /metrics endpoint on a configurable port.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
start_metrics_server("monitoring") # reads port from env or uses default
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default ports for worker types that serve custom Prometheus metrics.
|
||||
# Only add entries here when a worker actually registers collectors.
|
||||
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
|
||||
# env var can override.
|
||||
_DEFAULT_PORTS: dict[str, int] = {
|
||||
"monitoring": 9096,
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
_server_lock = threading.Lock()
|
||||
|
||||
|
||||
def start_metrics_server(worker_type: str) -> int | None:
|
||||
"""Start a Prometheus metrics HTTP server in a background thread.
|
||||
|
||||
Returns the port if started, None if disabled or already started.
|
||||
|
||||
Port resolution order:
|
||||
1. PROMETHEUS_METRICS_PORT env var (explicit override)
|
||||
2. Default port for the worker type
|
||||
3. If worker type is unknown and no env var, skip
|
||||
|
||||
Set PROMETHEUS_METRICS_ENABLED=false to disable.
|
||||
"""
|
||||
global _server_started
|
||||
|
||||
with _server_lock:
|
||||
if _server_started:
|
||||
logger.debug(f"Metrics server already started for {worker_type}")
|
||||
return None
|
||||
|
||||
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
|
||||
if enabled in ("false", "0", "no"):
|
||||
logger.info(f"Prometheus metrics server disabled for {worker_type}")
|
||||
return None
|
||||
|
||||
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
|
||||
if port_str:
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
|
||||
"must be a numeric port. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
elif worker_type in _DEFAULT_PORTS:
|
||||
port = _DEFAULT_PORTS[worker_type]
|
||||
else:
|
||||
logger.info(
|
||||
f"No default metrics port for worker type '{worker_type}' "
|
||||
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
start_http_server(port)
|
||||
_server_started = True
|
||||
logger.info(
|
||||
f"Prometheus metrics server started on :{port} for {worker_type}"
|
||||
)
|
||||
return port
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
|
||||
)
|
||||
return None
|
||||
@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -81,6 +83,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +573,38 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in run_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +695,26 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
_user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
try:
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -41,6 +41,16 @@ class MessageResponseIDInfo(BaseModel):
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class MultiModelMessageResponseIDInfo(BaseModel):
|
||||
"""Sent at the start of a multi-model streaming response.
|
||||
Contains the user message ID and the reserved assistant message IDs
|
||||
for each model being run in parallel."""
|
||||
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_ids: list[int]
|
||||
model_names: list[str]
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
source: DocumentSource
|
||||
|
||||
@@ -86,6 +96,9 @@ class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
llm_override: LLMOverride | None = None
|
||||
# For multi-model mode: up to 3 LLM overrides to run in parallel.
|
||||
# When provided with >1 entry, triggers multi-model streaming.
|
||||
llm_overrides: list[LLMOverride] | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
|
||||
@@ -211,6 +224,8 @@ class ChatMessageDetail(BaseModel):
|
||||
error: str | None = None
|
||||
current_feedback: str | None = None # "like" | "dislike" | null
|
||||
processing_duration_seconds: float | None = None
|
||||
preferred_response_id: int | None = None
|
||||
model_display_name: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -218,6 +233,11 @@ class ChatMessageDetail(BaseModel):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class SetPreferredResponseRequest(BaseModel):
|
||||
user_message_id: int
|
||||
preferred_response_id: int
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
|
||||
@@ -8,3 +8,5 @@ class Placement(BaseModel):
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
@@ -81,7 +80,6 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -104,7 +104,5 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -736,7 +736,7 @@ if __name__ == "__main__":
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
|
||||
persona = get_default_behavior_persona(db_session)
|
||||
if persona is None:
|
||||
raise ValueError("No default persona found")
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -125,12 +124,7 @@ def construct_tools(
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs.
|
||||
|
||||
Will simply skip tools that are not allowed/available.
|
||||
|
||||
Callers must supply a persona with ``tools``, ``document_sets``,
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
Will simply skip tools that are not allowed/available."""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
@@ -149,28 +143,6 @@ def construct_tools(
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
|
||||
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
|
||||
persona_search_info = PersonaSearchInfo(
|
||||
document_set_names=[ds.name for ds in persona.document_sets],
|
||||
search_start_date=persona.search_start_date,
|
||||
attached_document_ids=[doc.id for doc in persona.attached_documents],
|
||||
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
|
||||
)
|
||||
return SearchTool(
|
||||
tool_id=tool_id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona_search_info=persona_search_info,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=config.user_selected_filters,
|
||||
project_id_filter=config.project_id_filter,
|
||||
persona_id_filter=config.persona_id_filter,
|
||||
bypass_acl=config.bypass_acl,
|
||||
slack_context=config.slack_context,
|
||||
enable_slack_search=config.enable_slack_search,
|
||||
)
|
||||
|
||||
added_search_tool = False
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
@@ -204,9 +176,22 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
_build_search_tool(db_tool_model.id, search_tool_config)
|
||||
]
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -436,12 +421,26 @@ def construct_tools(
|
||||
# Get the database tool model for SearchTool
|
||||
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
|
||||
|
||||
# Use the passed-in config if available, otherwise create a new one
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [
|
||||
_build_search_tool(search_tool_db_model.id, search_tool_config)
|
||||
]
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [search_tool]
|
||||
|
||||
# Always inject MemoryTool when the user has the memory tool enabled,
|
||||
# bypassing persona tool associations and allowed_tool_ids filtering
|
||||
|
||||
@@ -51,7 +51,6 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
@@ -66,6 +65,7 @@ from onyx.db.federated import (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names,
|
||||
)
|
||||
from onyx.db.federated import list_federated_connector_oauth_tokens
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
emitter: Emitter,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Pre-extracted persona search configuration
|
||||
persona_search_info: PersonaSearchInfo,
|
||||
# Used for filter settings
|
||||
persona: Persona,
|
||||
llm: LLM,
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
super().__init__(emitter=emitter)
|
||||
|
||||
self.user = user
|
||||
self.persona_search_info = persona_search_info
|
||||
self.persona = persona
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Case 1: Slack bot context — requires a Slack federated connector
|
||||
# linked via the persona's document sets
|
||||
if self.slack_context:
|
||||
document_set_names = self.persona_search_info.document_set_names
|
||||
document_set_names = [ds.name for ds in self.persona.document_sets]
|
||||
if not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping Slack federated search: no document sets on persona"
|
||||
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona_search_info=self.persona_search_info,
|
||||
persona=self.persona,
|
||||
acl_filters=acl_filters,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=federated_retrieval_infos,
|
||||
@@ -587,12 +587,15 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
and self.user_selected_filters.source_type
|
||||
else None
|
||||
)
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=self.persona_search_info.document_set_names,
|
||||
document_set_names=persona_document_sets,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import UserProject
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
|
||||
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
|
||||
user = create_test_user(db_session, "eager-load")
|
||||
persona = Persona(name="eager-load-test", description="test")
|
||||
project = UserProject(name="eager-load-project", user_id=user.id)
|
||||
db_session.add_all([persona, project])
|
||||
db_session.flush()
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="test",
|
||||
user_id=None,
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
|
||||
loaded = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
try:
|
||||
tmp = inspect(loaded)
|
||||
assert tmp is not None
|
||||
unloaded = tmp.unloaded
|
||||
assert "persona" not in unloaded
|
||||
assert "project" not in unloaded
|
||||
|
||||
tmp = inspect(loaded.persona)
|
||||
assert tmp is not None
|
||||
persona_unloaded = tmp.unloaded
|
||||
assert "tools" not in persona_unloaded
|
||||
assert "user_files" not in persona_unloaded
|
||||
assert "document_sets" not in persona_unloaded
|
||||
assert "attached_documents" not in persona_unloaded
|
||||
assert "hierarchy_nodes" not in persona_unloaded
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
@@ -139,12 +139,12 @@ def use_mock_search_pipeline(
|
||||
chunk_search_request: ChunkSearchRequest,
|
||||
document_index: DocumentIndex, # noqa: ARG001
|
||||
user: User | None, # noqa: ARG001
|
||||
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
|
||||
persona: Persona | None, # noqa: ARG001
|
||||
db_session: Session | None = None, # noqa: ARG001
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id_filter: int | None = None, # noqa: ARG001
|
||||
persona_id_filter: int | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Tests for user group rename DB operation."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
class TestRenameUserGroup:
|
||||
"""Tests for rename_user_group function."""
|
||||
|
||||
@patch("ee.onyx.db.user_group.DISABLE_VECTOR_DB", False)
|
||||
@patch(
|
||||
"ee.onyx.db.user_group._mark_user_group__cc_pair_relationships_outdated__no_commit"
|
||||
)
|
||||
def test_rename_succeeds_and_triggers_sync(
|
||||
self, mock_mark_outdated: MagicMock
|
||||
) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_group = MagicMock(spec=UserGroup)
|
||||
mock_group.name = "Old Name"
|
||||
mock_group.is_up_to_date = True
|
||||
mock_session.scalar.return_value = mock_group
|
||||
|
||||
result = rename_user_group(mock_session, user_group_id=1, new_name="New Name")
|
||||
|
||||
assert result.name == "New Name"
|
||||
assert result.is_up_to_date is False
|
||||
mock_mark_outdated.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_rename_group_not_found(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
rename_user_group(mock_session, user_group_id=999, new_name="New Name")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_rename_group_syncing_raises(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_group = MagicMock(spec=UserGroup)
|
||||
mock_group.is_up_to_date = False
|
||||
mock_session.scalar.return_value = mock_group
|
||||
|
||||
with pytest.raises(ValueError, match="currently syncing"):
|
||||
rename_user_group(mock_session, user_group_id=1, new_name="New Name")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
@@ -1,216 +0,0 @@
|
||||
"""
|
||||
Unit tests for the check_available_tenants task.
|
||||
|
||||
Tests verify:
|
||||
- Provisioning loop calls pre_provision_tenant the correct number of times
|
||||
- Batch size is capped at _MAX_TENANTS_PER_RUN
|
||||
- A failure in one provisioning call does not stop subsequent calls
|
||||
- No provisioning happens when pool is already full
|
||||
- TARGET_AVAILABLE_TENANTS is respected
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
|
||||
_MAX_TENANTS_PER_RUN,
|
||||
)
|
||||
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
|
||||
check_available_tenants,
|
||||
)
|
||||
|
||||
# Access the underlying function directly, bypassing Celery's task wrapper
|
||||
# which injects `self` as the first argument when bind=True.
|
||||
_check_available_tenants = check_available_tenants.run
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _enable_multi_tenant(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_redis(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.acquire.return_value = True
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.lock.return_value = mock_lock
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_redis_client",
|
||||
lambda tenant_id: mock_client, # noqa: ARG005
|
||||
)
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_pre_provision(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
mock = MagicMock(return_value=True)
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.pre_provision_tenant",
|
||||
mock,
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def _mock_available_count(monkeypatch: pytest.MonkeyPatch, count: int) -> None:
|
||||
"""Set up the DB session mock to return a specific available tenant count."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.query.return_value.count.return_value = count
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_session_with_shared_schema",
|
||||
lambda: mock_session,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_multi_tenant", "mock_redis")
|
||||
class TestCheckAvailableTenants:
|
||||
def test_provisions_all_needed_tenants(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool has 2 and target is 5, should provision 3."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 2)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 3
|
||||
|
||||
def test_batch_capped_at_max_per_run(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool needs more than _MAX_TENANTS_PER_RUN, cap the batch."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
20,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 0)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == _MAX_TENANTS_PER_RUN
|
||||
|
||||
def test_no_provisioning_when_pool_full(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool already meets target, should not provision anything."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 5)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_no_provisioning_when_pool_exceeds_target(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""When pool exceeds target, should not provision anything."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 8)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_failure_does_not_stop_remaining(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""If one provisioning fails, the rest should still be attempted."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 0)
|
||||
|
||||
# Fail on calls 2 and 4 (1-indexed)
|
||||
call_count = 0
|
||||
|
||||
def side_effect() -> bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count in (2, 4):
|
||||
raise RuntimeError("provisioning failed")
|
||||
return True
|
||||
|
||||
mock_pre_provision.side_effect = side_effect
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
# All 5 should be attempted despite 2 failures
|
||||
assert mock_pre_provision.call_count == 5
|
||||
|
||||
def test_skips_when_not_multi_tenant(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""Should not provision when multi-tenancy is disabled."""
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
|
||||
False,
|
||||
)
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_skips_when_lock_not_acquired(
|
||||
self,
|
||||
mock_redis: MagicMock,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""Should skip when another instance holds the lock."""
|
||||
mock_redis.lock.return_value.acquire.return_value = False
|
||||
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 0
|
||||
|
||||
def test_lock_release_failure_does_not_raise(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_redis: MagicMock,
|
||||
mock_pre_provision: MagicMock,
|
||||
) -> None:
|
||||
"""LockNotOwnedError on release should be caught, not propagated."""
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
|
||||
5,
|
||||
)
|
||||
_mock_available_count(monkeypatch, 4)
|
||||
|
||||
mock_redis.lock.return_value.release.side_effect = LockNotOwnedError("expired")
|
||||
|
||||
# Should not raise
|
||||
_check_available_tenants()
|
||||
|
||||
assert mock_pre_provision.call_count == 1
|
||||
206
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
206
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in run_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = run_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.query.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.query.return_value.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.query.return_value.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
134
backend/tests/unit/onyx/chat/test_multi_model_types.py
Normal file
134
backend/tests/unit/onyx/chat/test_multi_model_types.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Unit tests for multi-model answer generation types.
|
||||
|
||||
Tests cover:
|
||||
- Placement.model_index serialization
|
||||
- MultiModelMessageResponseIDInfo round-trip
|
||||
- SendMessageRequest.llm_overrides backward compatibility
|
||||
- ChatMessageDetail new fields
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
|
||||
class TestPlacementModelIndex:
|
||||
def test_default_none(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
assert p.model_index is None
|
||||
|
||||
def test_set_value(self) -> None:
|
||||
p = Placement(turn_index=0, model_index=2)
|
||||
assert p.model_index == 2
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
p = Placement(turn_index=0, tab_index=1, model_index=1)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] == 1
|
||||
|
||||
def test_none_excluded_when_default(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] is None
|
||||
|
||||
|
||||
class TestMultiModelMessageResponseIDInfo:
|
||||
def test_round_trip(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=42,
|
||||
reserved_assistant_message_ids=[43, 44, 45],
|
||||
model_names=["gpt-4", "claude-opus", "gemini-pro"],
|
||||
)
|
||||
d = info.model_dump()
|
||||
restored = MultiModelMessageResponseIDInfo(**d)
|
||||
assert restored.user_message_id == 42
|
||||
assert restored.reserved_assistant_message_ids == [43, 44, 45]
|
||||
assert restored.model_names == ["gpt-4", "claude-opus", "gemini-pro"]
|
||||
|
||||
def test_null_user_message_id(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=None,
|
||||
reserved_assistant_message_ids=[1, 2],
|
||||
model_names=["a", "b"],
|
||||
)
|
||||
assert info.user_message_id is None
|
||||
|
||||
|
||||
class TestSendMessageRequestOverrides:
|
||||
def test_llm_overrides_default_none(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
)
|
||||
assert req.llm_overrides is None
|
||||
|
||||
def test_llm_overrides_accepts_list(self) -> None:
|
||||
overrides = [
|
||||
LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
LLMOverride(model_provider="anthropic", model_version="claude-opus"),
|
||||
]
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_overrides=overrides,
|
||||
)
|
||||
assert req.llm_overrides is not None
|
||||
assert len(req.llm_overrides) == 2
|
||||
|
||||
def test_backward_compat_single_override(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
)
|
||||
assert req.llm_override is not None
|
||||
assert req.llm_overrides is None
|
||||
|
||||
|
||||
class TestChatMessageDetailMultiModel:
|
||||
def test_defaults_none(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
)
|
||||
assert detail.preferred_response_id is None
|
||||
assert detail.model_display_name is None
|
||||
|
||||
def test_set_values(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.USER,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
preferred_response_id=42,
|
||||
model_display_name="GPT-4",
|
||||
)
|
||||
assert detail.preferred_response_id == 42
|
||||
assert detail.model_display_name == "GPT-4"
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
model_display_name="Claude Opus",
|
||||
)
|
||||
d = detail.model_dump()
|
||||
assert d["model_display_name"] == "Claude Opus"
|
||||
assert d["preferred_response_id"] is None
|
||||
@@ -1,12 +1,4 @@
|
||||
import pytest
|
||||
|
||||
from onyx.chat.process_message import _resolve_query_processing_hook_result
|
||||
from onyx.chat.process_message import remove_answer_citations
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
|
||||
|
||||
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
|
||||
@@ -40,81 +32,3 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
|
||||
remove_answer_citations(answer)
|
||||
== "See [reference](https://example.com/Function_(mathematics)) for context."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query Processing hook response handling (_resolve_query_processing_hook_result)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_hook_skipped_leaves_message_text_unchanged() -> None:
|
||||
result = _resolve_query_processing_hook_result(HookSkipped(), "original query")
|
||||
assert result == "original query"
|
||||
|
||||
|
||||
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
|
||||
result = _resolve_query_processing_hook_result(HookSoftFailed(), "original query")
|
||||
assert result == "original query"
|
||||
|
||||
|
||||
def test_null_query_raises_query_rejected() -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=None), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_empty_string_query_raises_query_rejected() -> None:
|
||||
"""Empty string is falsy — must be treated as rejection, same as None."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=""), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_whitespace_only_query_raises_query_rejected() -> None:
|
||||
"""Whitespace-only string is truthy but meaningless — must be treated as rejection."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=" "), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_absent_query_field_raises_query_rejected() -> None:
|
||||
"""query defaults to None when not provided."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(), "original query"
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
|
||||
|
||||
|
||||
def test_rejection_message_surfaced_in_error_when_provided() -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(
|
||||
query=None, rejection_message="Queries about X are not allowed."
|
||||
),
|
||||
"original query",
|
||||
)
|
||||
assert "Queries about X are not allowed." in str(exc_info.value)
|
||||
|
||||
|
||||
def test_fallback_rejection_message_when_none() -> None:
|
||||
"""No rejection_message → generic fallback used in OnyxError detail."""
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
_resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query=None, rejection_message=None),
|
||||
"original query",
|
||||
)
|
||||
assert "No rejection reason was provided." in str(exc_info.value)
|
||||
|
||||
|
||||
def test_nonempty_query_rewrites_message_text() -> None:
|
||||
result = _resolve_query_processing_hook_result(
|
||||
QueryProcessingResponse(query="rewritten query"), "original query"
|
||||
)
|
||||
assert result == "rewritten query"
|
||||
|
||||
@@ -60,4 +60,4 @@ def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
with pytest.raises(HTTPError):
|
||||
handled_call()
|
||||
|
||||
assert mock_confluence_call.call_count == 5
|
||||
assert mock_confluence_call.call_count == 1
|
||||
|
||||
@@ -7,7 +7,6 @@ from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -16,15 +15,13 @@ from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
|
||||
# A valid QueryProcessingResponse payload — used by success-path tests.
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
|
||||
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
|
||||
|
||||
|
||||
def _make_hook(
|
||||
@@ -36,7 +33,6 @@ def _make_hook(
|
||||
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
|
||||
hook_id: int = 1,
|
||||
is_reachable: bool | None = None,
|
||||
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
|
||||
) -> MagicMock:
|
||||
hook = MagicMock()
|
||||
hook.is_active = is_active
|
||||
@@ -46,7 +42,6 @@ def _make_hook(
|
||||
hook.id = hook_id
|
||||
hook.fail_strategy = fail_strategy
|
||||
hook.is_reachable = is_reachable
|
||||
hook.hook_point = hook_point
|
||||
return hook
|
||||
|
||||
|
||||
@@ -145,7 +140,6 @@ def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSkipped)
|
||||
@@ -158,9 +152,7 @@ def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_success_returns_validated_model_and_sets_reachable(
|
||||
db_session: MagicMock,
|
||||
) -> None:
|
||||
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
@@ -179,11 +171,9 @@ def test_success_returns_validated_model_and_sets_reachable(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
_, update_kwargs = mock_update.call_args
|
||||
assert update_kwargs["is_reachable"] is True
|
||||
mock_log.assert_not_called()
|
||||
@@ -210,11 +200,9 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
assert result == _RESPONSE_PAYLOAD
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
@@ -242,7 +230,6 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
@@ -278,7 +265,6 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
@@ -402,7 +388,6 @@ def test_http_failure_paths(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
@@ -410,7 +395,6 @@ def test_http_failure_paths(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert isinstance(result, expected_type)
|
||||
|
||||
@@ -458,7 +442,6 @@ def test_authorization_header(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
_, call_kwargs = mock_client.post.call_args
|
||||
@@ -474,16 +457,16 @@ def test_authorization_header(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"http_exception,expect_onyx_error",
|
||||
"http_exception,expected_result",
|
||||
[
|
||||
pytest.param(None, False, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
|
||||
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
|
||||
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
|
||||
],
|
||||
)
|
||||
def test_persist_session_failure_is_swallowed(
|
||||
db_session: MagicMock,
|
||||
http_exception: Exception | None,
|
||||
expect_onyx_error: bool,
|
||||
expected_result: Any,
|
||||
) -> None:
|
||||
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
@@ -506,13 +489,12 @@ def test_persist_session_failure_is_swallowed(
|
||||
side_effect=http_exception,
|
||||
)
|
||||
|
||||
if expect_onyx_error:
|
||||
if expected_result is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
@@ -520,131 +502,8 @@ def test_persist_session_failure_is_swallowed(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert isinstance(result, QueryProcessingResponse)
|
||||
assert result.query == _RESPONSE_PAYLOAD["query"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response model validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StrictResponse(BaseModel):
|
||||
"""Strict model used to reliably trigger a ValidationError in tests."""
|
||||
|
||||
required_field: str # no default → missing key raises ValidationError
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_strategy,expected_type",
|
||||
[
|
||||
pytest.param(
|
||||
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
|
||||
),
|
||||
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
|
||||
],
|
||||
)
|
||||
def test_response_validation_failure_respects_fail_strategy(
|
||||
db_session: MagicMock,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
) -> None:
|
||||
"""A response that fails response_model validation is treated like any other
|
||||
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
# Response payload is missing required_field → ValidationError
|
||||
_setup_client(mock_client_cls, response=_make_response(json_return={}))
|
||||
|
||||
if expected_type is OnyxError:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=_StrictResponse,
|
||||
)
|
||||
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
|
||||
else:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=_StrictResponse,
|
||||
)
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
|
||||
# is_reachable must not be updated — server responded correctly
|
||||
mock_update.assert_not_called()
|
||||
# failure must be logged
|
||||
mock_log.assert_called_once()
|
||||
_, log_kwargs = mock_log.call_args
|
||||
assert log_kwargs["is_success"] is False
|
||||
assert "validation" in (log_kwargs["error_message"] or "").lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Outer soft-fail guard in execute_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_strategy,expected_type",
|
||||
[
|
||||
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
|
||||
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
|
||||
],
|
||||
)
|
||||
def test_unexpected_exception_in_inner_respects_fail_strategy(
|
||||
db_session: MagicMock,
|
||||
fail_strategy: HookFailStrategy,
|
||||
expected_type: type,
|
||||
) -> None:
|
||||
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
|
||||
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
|
||||
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor._execute_hook_inner",
|
||||
side_effect=ValueError("unexpected bug"),
|
||||
),
|
||||
):
|
||||
if expected_type is HookSoftFailed:
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
else:
|
||||
with pytest.raises(ValueError, match="unexpected bug"):
|
||||
execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
|
||||
@@ -676,7 +535,6 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=_PAYLOAD,
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
assert isinstance(result, HookSoftFailed)
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
"""Tests for generic Celery task lifecycle Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.celery_task_metrics import _task_start_times
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
|
||||
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_metrics() -> Iterator[None]:
|
||||
"""Clear metric state between tests."""
|
||||
_task_start_times.clear()
|
||||
yield
|
||||
_task_start_times.clear()
|
||||
|
||||
|
||||
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
task.request = MagicMock()
|
||||
task.request.delivery_info = {"routing_key": queue}
|
||||
return task
|
||||
|
||||
|
||||
class TestCeleryTaskPrerun:
|
||||
def test_increments_started_and_active(self) -> None:
|
||||
task = _make_task()
|
||||
before_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
before_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
after_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
after_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
assert after_started == before_started + 1
|
||||
assert after_active == before_active + 1
|
||||
|
||||
def test_records_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_prerun("task-1", None)
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_id_is_none(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun(None, task)
|
||||
# Should not crash
|
||||
|
||||
def test_handles_missing_delivery_info(self) -> None:
|
||||
task = _make_task()
|
||||
task.request.delivery_info = None
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
|
||||
class TestCeleryTaskPostrun:
|
||||
def test_increments_completed_success(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_increments_completed_failure(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "FAILURE")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_decrements_active(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
active_before = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
active_after = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
assert active_after == active_before - 1
|
||||
|
||||
def test_observes_duration(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
# Duration should have increased (at least slightly)
|
||||
assert after_count > before_count
|
||||
|
||||
def test_cleans_up_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_postrun("task-1", None, "SUCCESS")
|
||||
|
||||
def test_handles_missing_start_time(self) -> None:
|
||||
"""Postrun without prerun should not crash."""
|
||||
task = _make_task()
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
# Should not raise
|
||||
@@ -1,359 +0,0 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value={"task-1", "task-2"},
|
||||
),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 3
|
||||
depth_family = families[0]
|
||||
unacked_family = families[1]
|
||||
age_family = families[2]
|
||||
|
||||
assert depth_family.name == "onyx_queue_depth"
|
||||
assert len(depth_family.samples) > 0
|
||||
for sample in depth_family.samples:
|
||||
assert sample.value == 5
|
||||
|
||||
assert unacked_family.name == "onyx_queue_unacked"
|
||||
unacked_labels = {s.labels["queue"] for s in unacked_family.samples}
|
||||
assert "docfetching" in unacked_labels
|
||||
assert "docprocessing" in unacked_labels
|
||||
|
||||
assert age_family.name == "onyx_queue_oldest_task_age_seconds"
|
||||
for sample in unacked_family.samples:
|
||||
assert sample.value == 2
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("connection lost"),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
# Returns stale cache (empty on first call)
|
||||
assert families == []
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
first = collector.collect()
|
||||
|
||||
# Second call within TTL should return cached result without calling Redis
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("should not be called"),
|
||||
):
|
||||
second = collector.collect()
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=10,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
good_result = collector.collect()
|
||||
|
||||
assert len(good_result) == 3
|
||||
assert good_result[0].samples[0].value == 10
|
||||
|
||||
# Second call fails — should return stale cache, not empty
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("Redis down"),
|
||||
):
|
||||
stale_result = collector.collect()
|
||||
|
||||
assert stale_result is good_result
|
||||
|
||||
|
||||
class TestIndexAttemptCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_index_attempts(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
mock_row = (
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
MagicMock(value="web"),
|
||||
81,
|
||||
"Table Tennis Blade Guide",
|
||||
2,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
|
||||
mock_row
|
||||
]
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 1
|
||||
assert families[0].name == "onyx_index_attempts_active"
|
||||
assert len(families[0].samples) == 1
|
||||
sample = families[0].samples[0]
|
||||
assert sample.labels == {
|
||||
"status": "in_progress",
|
||||
"source": "web",
|
||||
"tenant_id": "public",
|
||||
"connector_name": "Table Tennis Blade Guide",
|
||||
"cc_pair_id": "81",
|
||||
}
|
||||
assert sample.value == 2
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
# No stale cache, so returns empty
|
||||
assert families == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_skips_none_tenant_ids(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = [None]
|
||||
families = collector.collect()
|
||||
assert len(families) == 1 # Returns the gauge family, just with no samples
|
||||
assert len(families[0].samples) == 0
|
||||
|
||||
|
||||
class TestConnectorHealthCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_connector_health(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
last_success = now - timedelta(hours=2)
|
||||
|
||||
mock_status = MagicMock(value="ACTIVE")
|
||||
mock_source = MagicMock(value="google_drive")
|
||||
# Row: (id, status, in_error, last_success, name, source)
|
||||
mock_row = (
|
||||
42,
|
||||
mock_status,
|
||||
True, # in_repeated_error_state
|
||||
last_success,
|
||||
"My GDrive Connector",
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
# Mock the index attempt queries (error counts + docs counts)
|
||||
mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
|
||||
[]
|
||||
)
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 6
|
||||
names = {f.name for f in families}
|
||||
assert names == {
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"onyx_connector_in_error_state",
|
||||
"onyx_connectors_by_status",
|
||||
"onyx_connectors_in_error_total",
|
||||
"onyx_connector_docs_indexed",
|
||||
"onyx_connector_error_count",
|
||||
}
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 1
|
||||
assert staleness.samples[0].value == pytest.approx(7200, abs=5)
|
||||
|
||||
error_state = next(
|
||||
f for f in families if f.name == "onyx_connector_in_error_state"
|
||||
)
|
||||
assert error_state.samples[0].value == 1.0
|
||||
|
||||
by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
|
||||
assert by_status.samples[0].labels == {
|
||||
"tenant_id": "public",
|
||||
"status": "ACTIVE",
|
||||
}
|
||||
assert by_status.samples[0].value == 1
|
||||
|
||||
error_total = next(
|
||||
f for f in families if f.name == "onyx_connectors_in_error_total"
|
||||
)
|
||||
assert error_total.samples[0].value == 1
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_skips_staleness_when_no_last_success(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_status = MagicMock(value="INITIAL_INDEXING")
|
||||
mock_source = MagicMock(value="slack")
|
||||
mock_row = (
|
||||
10,
|
||||
mock_status,
|
||||
False,
|
||||
None, # no last_successful_index_time
|
||||
0,
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 0
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
assert families == []
|
||||
@@ -1,96 +0,0 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
@@ -1,335 +0,0 @@
|
||||
"""Tests for per-connector indexing task Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_task_metrics import _connector_cache
|
||||
from onyx.server.metrics.indexing_task_metrics import _indexing_start_times
|
||||
from onyx.server.metrics.indexing_task_metrics import ConnectorInfo
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_COMPLETED
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_DURATION
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_STARTED
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state() -> Iterator[None]:
|
||||
"""Clear caches and state between tests.
|
||||
|
||||
Sets CURRENT_TENANT_ID_CONTEXTVAR to a realistic value so cache keys
|
||||
are never keyed on an empty string.
|
||||
"""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test_tenant")
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
yield
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _make_task(name: str) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
return task
|
||||
|
||||
|
||||
def _mock_db_lookup(
|
||||
source: str = "google_drive", name: str = "My Google Drive"
|
||||
) -> tuple:
|
||||
"""Return (session_patch, cc_pair_patch) context managers for DB mocking."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = name
|
||||
mock_cc_pair.connector.source.value = source
|
||||
|
||||
session_patch = patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
cc_pair_patch = patch(
|
||||
"onyx.db.connector_credential_pair.get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
)
|
||||
return session_patch, cc_pair_patch
|
||||
|
||||
|
||||
class TestIndexingTaskPrerun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docfetching(self) -> None:
|
||||
# Pre-populate cache to avoid DB lookup (tenant-scoped key)
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="My Google Drive"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "tenant-1"}
|
||||
|
||||
before = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
after = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docprocessing(self) -> None:
|
||||
_connector_cache[("test_tenant", 10)] = ConnectorInfo(
|
||||
source="slack", name="Slack Connector"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 10, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-2", task, kwargs)
|
||||
assert "task-2" in _indexing_start_times
|
||||
|
||||
def test_cache_hit_avoids_db_call(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="confluence", name="Engineering Confluence"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No DB patches needed — cache should be used
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_db_lookup_on_cache_miss(self) -> None:
|
||||
"""On first encounter of a cc_pair_id, does a DB lookup and caches."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "Notion Workspace"
|
||||
mock_cc_pair.connector.source.value = "notion"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve,
|
||||
):
|
||||
mock_resolve.return_value = ConnectorInfo(
|
||||
source="notion", name="Notion Workspace"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 77, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
mock_resolve.assert_called_once_with(77)
|
||||
|
||||
def test_missing_cc_pair_returns_unknown(self) -> None:
|
||||
"""When _resolve_connector can't find the cc_pair, uses 'unknown'."""
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 999, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_skips_when_cc_pair_id_missing(self) -> None:
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_db_error_does_not_crash(self) -> None:
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector",
|
||||
side_effect=Exception("DB down"),
|
||||
):
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
# Should not raise
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
|
||||
class TestIndexingTaskPostrun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
# Should not raise
|
||||
|
||||
def test_emits_completed_and_duration(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="Marketing Drive"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# Simulate prerun
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
before_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
after_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
after_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
assert after_completed == before_completed + 1
|
||||
assert after_duration > before_duration
|
||||
|
||||
def test_failure_outcome(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "FAILURE")
|
||||
|
||||
after = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
|
||||
def test_handles_postrun_without_prerun(self) -> None:
|
||||
"""Postrun for an indexing task without a matching prerun should not crash."""
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No prerun — should still emit completed counter, just skip duration
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
|
||||
class TestResolveConnector:
|
||||
def test_failed_lookup_not_cached(self) -> None:
|
||||
"""When DB lookup returns None, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(999)
|
||||
assert result.source == "unknown"
|
||||
# Should NOT be cached so subsequent calls can retry
|
||||
assert ("test-tenant", 999) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_exception_not_cached(self) -> None:
|
||||
"""When DB lookup raises, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"onyx.db.engine.sql_engine.get_session_with_current_tenant",
|
||||
side_effect=Exception("DB down"),
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(888)
|
||||
assert result.source == "unknown"
|
||||
assert ("test-tenant", 888) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_successful_lookup_is_cached(self) -> None:
|
||||
"""When DB lookup succeeds, result should be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "My Drive"
|
||||
mock_cc_pair.connector.source.value = "google_drive"
|
||||
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(777)
|
||||
assert result.source == "google_drive"
|
||||
assert result.name == "My Drive"
|
||||
assert ("test-tenant", 777) in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""Tests for the Prometheus metrics server module."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.metrics_server import _DEFAULT_PORTS
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_server_state() -> Iterator[None]:
|
||||
"""Reset the global _server_started between tests."""
|
||||
import onyx.server.metrics.metrics_server as mod
|
||||
|
||||
mod._server_started = False
|
||||
yield
|
||||
mod._server_started = False
|
||||
|
||||
|
||||
class TestStartMetricsServer:
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_uses_default_port_for_known_worker(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == _DEFAULT_PORTS["monitoring"]
|
||||
mock_start.assert_called_once_with(_DEFAULT_PORTS["monitoring"])
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "9999"})
|
||||
def test_env_var_overrides_default(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == 9999
|
||||
mock_start.assert_called_once_with(9999)
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_ENABLED": "false"})
|
||||
def test_disabled_via_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_unknown_worker_type_no_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("unknown_worker")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_idempotent(self, mock_start: MagicMock) -> None:
|
||||
port1 = start_metrics_server("monitoring")
|
||||
port2 = start_metrics_server("monitoring")
|
||||
assert port1 == _DEFAULT_PORTS["monitoring"]
|
||||
assert port2 is None
|
||||
mock_start.assert_called_once()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_handles_os_error(self, mock_start: MagicMock) -> None:
|
||||
mock_start.side_effect = OSError("Address already in use")
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "not_a_number"})
|
||||
def test_invalid_port_env_var_returns_none(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
@@ -23,12 +23,6 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -52,10 +46,8 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
# need to use 1.1 to support chunked transfers
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
|
||||
@@ -23,12 +23,6 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -53,10 +47,8 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
# need to use 1.1 to support chunked transfers
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# we don't want nginx trying to do something clever with
|
||||
@@ -100,8 +92,6 @@ server {
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
# we don't want nginx trying to do something clever with
|
||||
# redirects, we set the Host: header above already.
|
||||
|
||||
@@ -23,12 +23,6 @@ upstream web_server {
|
||||
# Conditionally include MCP upstream configuration
|
||||
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
|
||||
@@ -53,10 +47,8 @@ server {
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers and WebSocket
|
||||
# need to use 1.1 to support chunked transfers
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
@@ -114,8 +106,6 @@ server {
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
|
||||
# timeout settings
|
||||
|
||||
@@ -28,12 +28,6 @@ data:
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
|
||||
map $http_upgrade $connection_upgrade {
|
||||
default upgrade;
|
||||
'' close;
|
||||
}
|
||||
|
||||
server.conf: |
|
||||
server {
|
||||
listen 1024;
|
||||
@@ -71,8 +65,6 @@ data:
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
# timeout settings
|
||||
|
||||
@@ -10,7 +10,7 @@ data:
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
HOST="${PGINTO_HOST:-${POSTGRES_HOST:-localhost}}"
|
||||
HOST="${POSTGRES_HOST:-localhost}"
|
||||
PORT="${POSTGRES_PORT:-5432}"
|
||||
USER="${POSTGRES_USER:-postgres}"
|
||||
DB="${POSTGRES_DB:-postgres}"
|
||||
|
||||
@@ -103,7 +103,7 @@ opensearch:
|
||||
- name: OPENSEARCH_INITIAL_ADMIN_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-opensearch # Must match auth.opensearch.secretName or auth.opensearch.existingSecret if defined.
|
||||
name: onyx-opensearch # Must match auth.opensearch.secretName.
|
||||
key: opensearch_admin_password # Must match auth.opensearch.secretKeys value.
|
||||
|
||||
resources:
|
||||
@@ -282,7 +282,7 @@ nginx:
|
||||
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
|
||||
# Workaround: Helm upgrade will restart if the following annotation value changes.
|
||||
podAnnotations:
|
||||
onyx.app/nginx-config-version: "2"
|
||||
onyx.app/nginx-config-version: "1"
|
||||
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
|
||||
@@ -83,14 +83,6 @@
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Disabled,
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
@@ -36,6 +32,9 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
*/
|
||||
size?: ContainerSizeVariants;
|
||||
|
||||
/** HTML button type. When provided, Container renders a `<button>` element. */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
@@ -44,9 +43,6 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -63,7 +59,6 @@ function Button({
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
responsiveHideText = false,
|
||||
disabled,
|
||||
...interactiveProps
|
||||
}: ButtonProps) {
|
||||
const isLarge = size === "lg";
|
||||
@@ -81,7 +76,7 @@ function Button({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateless type={type} {...interactiveProps}>
|
||||
<Interactive.Stateless {...interactiveProps}>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
border={interactiveProps.prominence === "secondary"}
|
||||
@@ -107,7 +102,9 @@ function Button({
|
||||
</Interactive.Stateless>
|
||||
);
|
||||
|
||||
const result = tooltip ? (
|
||||
if (!tooltip) return button;
|
||||
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
@@ -120,15 +117,7 @@ function Button({
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
) : (
|
||||
button
|
||||
);
|
||||
|
||||
if (disabled != null) {
|
||||
return <Disabled disabled={disabled}>{result}</Disabled>;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export { Button, type ButtonProps };
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
.opal-button-chevron {
|
||||
transition: rotate 200ms ease;
|
||||
}
|
||||
|
||||
.interactive[data-interaction="hover"] .opal-button-chevron,
|
||||
.interactive[data-interaction="active"] .opal-button-chevron {
|
||||
rotate: -180deg;
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
import "@opal/components/buttons/chevron.css";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import { SvgChevronDownSmall } from "@opal/icons";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
/**
|
||||
* Chevron icon that rotates 180° when its parent `.interactive` enters
|
||||
* hover / active state. Shared by OpenButton, FilterButton, and any
|
||||
* future button that needs an animated dropdown indicator.
|
||||
*
|
||||
* Stable component identity — never causes React to remount the SVG.
|
||||
*/
|
||||
function ChevronIcon({ className, ...props }: IconProps) {
|
||||
return (
|
||||
<SvgChevronDownSmall
|
||||
className={cn(className, "opal-button-chevron")}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export { ChevronIcon };
|
||||
@@ -1,107 +0,0 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { FilterButton } from "@opal/components";
|
||||
import { Disabled as DisabledProvider } from "@opal/core";
|
||||
import { SvgUser, SvgActions, SvgTag } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
|
||||
const meta: Meta<typeof FilterButton> = {
|
||||
title: "opal/components/FilterButton",
|
||||
component: FilterButton,
|
||||
tags: ["autodocs"],
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<TooltipPrimitive.Provider>
|
||||
<Story />
|
||||
</TooltipPrimitive.Provider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof FilterButton>;
|
||||
|
||||
export const Empty: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
children: "Everyone",
|
||||
},
|
||||
};
|
||||
|
||||
export const Active: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
active: true,
|
||||
children: "By alice@example.com",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
};
|
||||
|
||||
export const Open: Story = {
|
||||
args: {
|
||||
icon: SvgActions,
|
||||
interaction: "hover",
|
||||
children: "All Actions",
|
||||
},
|
||||
};
|
||||
|
||||
export const ActiveOpen: Story = {
|
||||
args: {
|
||||
icon: SvgActions,
|
||||
active: true,
|
||||
interaction: "hover",
|
||||
children: "2 selected",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
icon: SvgTag,
|
||||
children: "All Tags",
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<DisabledProvider disabled>
|
||||
<Story />
|
||||
</DisabledProvider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export const DisabledActive: Story = {
|
||||
args: {
|
||||
icon: SvgTag,
|
||||
active: true,
|
||||
children: "2 tags",
|
||||
onClear: () => console.log("clear"),
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<DisabledProvider disabled>
|
||||
<Story />
|
||||
</DisabledProvider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export const StateComparison: Story = {
|
||||
render: () => (
|
||||
<div style={{ display: "flex", gap: 12, alignItems: "center" }}>
|
||||
<FilterButton icon={SvgUser} onClear={() => undefined}>
|
||||
Everyone
|
||||
</FilterButton>
|
||||
<FilterButton icon={SvgUser} active onClear={() => console.log("clear")}>
|
||||
By alice@example.com
|
||||
</FilterButton>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
export const WithTooltip: Story = {
|
||||
args: {
|
||||
icon: SvgUser,
|
||||
children: "Everyone",
|
||||
tooltip: "Filter by creator",
|
||||
tooltipSide: "bottom",
|
||||
},
|
||||
};
|
||||
@@ -1,70 +0,0 @@
|
||||
# FilterButton
|
||||
|
||||
**Import:** `import { FilterButton, type FilterButtonProps } from "@opal/components";`
|
||||
|
||||
A stateful filter trigger with a built-in chevron (when empty) and a clear button (when selected). Hardcodes `variant="select-filter"` and delegates to `Interactive.Stateful`, adding automatic open-state detection from Radix `data-state`. Designed to sit inside a `Popover.Trigger` for filter dropdowns.
|
||||
|
||||
## Relationship to OpenButton
|
||||
|
||||
FilterButton shares a similar call stack to `OpenButton`:
|
||||
|
||||
```
|
||||
Interactive.Stateful → Interactive.Container → content row (icon + label + trailing indicator)
|
||||
```
|
||||
|
||||
FilterButton is a **narrower, filter-specific** variant:
|
||||
|
||||
- It hardcodes `variant="select-filter"` (OpenButton uses `"select-heavy"`)
|
||||
- It exposes `active?: boolean` instead of the raw `state` prop (maps to `"selected"` / `"empty"` internally)
|
||||
- When active, the chevron is hidden via `visibility` and an absolutely-positioned clear `Button` with `prominence="tertiary"` overlays it — placed as a sibling outside the `<button>` to avoid nesting buttons
|
||||
- It uses the shared `ChevronIcon` from `buttons/chevron` (same as OpenButton)
|
||||
- It does not support `foldable`, `size`, or `width` — it is always `"lg"`
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
div.relative <- bounding wrapper
|
||||
Interactive.Stateful <- variant="select-filter", interaction, state
|
||||
└─ Interactive.Container (button) <- height="lg", default rounding/padding
|
||||
└─ div.interactive-foreground
|
||||
├─ div > Icon (interactive-foreground-icon)
|
||||
├─ <span> label text
|
||||
└─ ChevronIcon (when empty)
|
||||
OR spacer div (when selected — reserves chevron space)
|
||||
div.absolute <- clear Button overlay (when selected)
|
||||
└─ Button (SvgX, size="2xs", prominence="tertiary")
|
||||
```
|
||||
|
||||
- **Open-state detection** reads `data-state="open"` injected by Radix triggers (e.g. `Popover.Trigger`), falling back to the explicit `interaction` prop.
|
||||
- **Chevron rotation** uses the shared `ChevronIcon` component and `buttons/chevron.css`, which rotates 180deg when `data-interaction="hover"`.
|
||||
- **Clear button** is absolutely positioned outside the `<button>` element tree to avoid invalid nested `<button>` elements. An invisible spacer inside the button reserves the same space so layout doesn't shift between states.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `icon` | `IconFunctionComponent` | **required** | Left icon component |
|
||||
| `children` | `string` | **required** | Label text between icon and trailing indicator |
|
||||
| `active` | `boolean` | `false` | Whether the filter has an active selection |
|
||||
| `onClear` | `() => void` | **required** | Called when the clear (X) button is clicked |
|
||||
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"`. |
|
||||
| `tooltip` | `string` | — | Tooltip text shown on hover |
|
||||
| `tooltipSide` | `TooltipSide` | `"top"` | Which side the tooltip appears on |
|
||||
|
||||
## Usage
|
||||
|
||||
```tsx
|
||||
import { FilterButton } from "@opal/components";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
|
||||
// Inside a Popover (auto-detects open state)
|
||||
<Popover.Trigger asChild>
|
||||
<FilterButton
|
||||
icon={SvgUser}
|
||||
active={hasSelection}
|
||||
onClear={() => clearSelection()}
|
||||
>
|
||||
{hasSelection ? selectionLabel : "Everyone"}
|
||||
</FilterButton>
|
||||
</Popover.Trigger>
|
||||
```
|
||||
@@ -1,120 +0,0 @@
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveStatefulInteraction,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
import { ChevronIcon } from "@opal/components/buttons/chevron";
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FilterButtonProps
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
|
||||
/** Left icon — always visible. */
|
||||
icon: IconFunctionComponent;
|
||||
|
||||
/** Label text between icon and trailing indicator. */
|
||||
children: string;
|
||||
|
||||
/** Whether the filter has an active selection. @default false */
|
||||
active?: boolean;
|
||||
|
||||
/** Called when the clear (X) button is clicked in active state. */
|
||||
onClear: () => void;
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FilterButton
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function FilterButton({
|
||||
icon: Icon,
|
||||
children,
|
||||
onClear,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
active = false,
|
||||
interaction,
|
||||
...statefulProps
|
||||
}: FilterButtonProps) {
|
||||
// Derive open state: explicit prop > Radix data-state (injected via Slot chain)
|
||||
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
|
||||
| string
|
||||
| undefined;
|
||||
const resolvedInteraction: InteractiveStatefulInteraction =
|
||||
interaction ?? (dataState === "open" ? "hover" : "rest");
|
||||
|
||||
const button = (
|
||||
<div className="relative">
|
||||
<Interactive.Stateful
|
||||
{...statefulProps}
|
||||
variant="select-filter"
|
||||
interaction={resolvedInteraction}
|
||||
state={active ? "selected" : "empty"}
|
||||
>
|
||||
<Interactive.Container type="button">
|
||||
<div className="interactive-foreground flex flex-row items-center gap-1">
|
||||
{iconWrapper(Icon, "lg", true)}
|
||||
<span className="whitespace-nowrap font-main-ui-action">
|
||||
{children}
|
||||
</span>
|
||||
<div style={{ visibility: active ? "hidden" : "visible" }}>
|
||||
{iconWrapper(ChevronIcon, "lg", true)}
|
||||
</div>
|
||||
</div>
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateful>
|
||||
|
||||
{active && (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
{/* Force hover state so the X stays visually prominent against
|
||||
the inverted selected background — without this it renders
|
||||
dimmed and looks disabled. */}
|
||||
<Button
|
||||
icon={SvgX}
|
||||
size="2xs"
|
||||
prominence="tertiary"
|
||||
tooltip="Clear filter"
|
||||
interaction="hover"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onClear();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
if (!tooltip) return button;
|
||||
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
|
||||
export { FilterButton, type FilterButtonProps };
|
||||
@@ -1,5 +1,8 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveStatefulState,
|
||||
type InteractiveStatefulInteraction,
|
||||
type InteractiveStatefulProps,
|
||||
InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core";
|
||||
@@ -19,26 +22,40 @@ type ContentPassthroughProps = DistributiveOmit<
|
||||
"paddingVariant" | "widthVariant" | "ref" | "withInteractive"
|
||||
>;
|
||||
|
||||
type LineItemButtonOwnProps = Pick<
|
||||
InteractiveStatefulProps,
|
||||
| "state"
|
||||
| "interaction"
|
||||
| "onClick"
|
||||
| "href"
|
||||
| "target"
|
||||
| "group"
|
||||
| "ref"
|
||||
| "type"
|
||||
> & {
|
||||
type LineItemButtonOwnProps = {
|
||||
/** Interactive select variant. @default "select-light" */
|
||||
selectVariant?: "select-light" | "select-heavy";
|
||||
|
||||
/** Value state. @default "empty" */
|
||||
state?: InteractiveStatefulState;
|
||||
|
||||
/** JS-controllable interaction state override. @default "rest" */
|
||||
interaction?: InteractiveStatefulInteraction;
|
||||
|
||||
/** Click handler. */
|
||||
onClick?: InteractiveStatefulProps["onClick"];
|
||||
|
||||
/** When provided, renders an anchor instead of a div. */
|
||||
href?: string;
|
||||
|
||||
/** Anchor target (e.g. "_blank"). */
|
||||
target?: string;
|
||||
|
||||
/** Interactive group key. */
|
||||
group?: string;
|
||||
|
||||
/** Forwarded ref. */
|
||||
ref?: React.Ref<HTMLElement>;
|
||||
|
||||
/** Corner rounding preset (height is always content-driven). @default "default" */
|
||||
roundingVariant?: InteractiveContainerRoundingVariant;
|
||||
|
||||
/** Container width. @default "full" */
|
||||
width?: ExtremaSizeVariants;
|
||||
|
||||
/** HTML button type. @default "button" */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
@@ -62,11 +79,11 @@ function LineItemButton({
|
||||
target,
|
||||
group,
|
||||
ref,
|
||||
type = "button",
|
||||
|
||||
// Sizing
|
||||
roundingVariant = "default",
|
||||
width = "full",
|
||||
type = "button",
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
|
||||
|
||||
@@ -40,6 +40,13 @@ export const Open: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
disabled: true,
|
||||
children: "Disabled",
|
||||
},
|
||||
};
|
||||
|
||||
export const Foldable: Story = {
|
||||
args: {
|
||||
foldable: true,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import "@opal/components/buttons/open-button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
@@ -7,11 +9,24 @@ import {
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type { InteractiveContainerRoundingVariant } from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, IconProps } from "@opal/types";
|
||||
import { SvgChevronDownSmall } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { cn } from "@opal/utils";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
import { ChevronIcon } from "@opal/components/buttons/chevron";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Chevron (stable identity — never causes React to remount the SVG)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ChevronIcon({ className, ...props }: IconProps) {
|
||||
return (
|
||||
<SvgChevronDownSmall
|
||||
className={cn(className, "opal-open-button-chevron")}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
.opal-open-button-chevron {
|
||||
transition: rotate 200ms ease;
|
||||
}
|
||||
|
||||
.interactive[data-interaction="hover"] .opal-open-button-chevron,
|
||||
.interactive[data-interaction="active"] .opal-open-button-chevron {
|
||||
rotate: -180deg;
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import "@opal/components/buttons/select-button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
@@ -49,6 +50,9 @@ type SelectButtonProps = InteractiveStatefulProps &
|
||||
*/
|
||||
size?: ContainerSizeVariants;
|
||||
|
||||
/** HTML button type. Container renders a `<button>` element. */
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/** Tooltip text shown on hover. */
|
||||
tooltip?: string;
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
|
||||
/* Shared types */
|
||||
export type TooltipSide = "top" | "bottom" | "left" | "right";
|
||||
|
||||
@@ -21,12 +19,6 @@ export {
|
||||
type OpenButtonProps,
|
||||
} from "@opal/components/buttons/open-button/components";
|
||||
|
||||
/* FilterButton */
|
||||
export {
|
||||
FilterButton,
|
||||
type FilterButtonProps,
|
||||
} from "@opal/components/buttons/filter-button/components";
|
||||
|
||||
/* LineItemButton */
|
||||
export {
|
||||
LineItemButton,
|
||||
|
||||
@@ -32,13 +32,7 @@ function ColumnVisibilityPopover<TData extends RowData>({
|
||||
// User-defined columns only (exclude internal qualifier/actions)
|
||||
const dataColumns = table
|
||||
.getAllLeafColumns()
|
||||
.filter(
|
||||
(col) =>
|
||||
!col.id.startsWith("__") &&
|
||||
col.id !== "qualifier" &&
|
||||
typeof col.columnDef.header === "string" &&
|
||||
col.columnDef.header.trim() !== ""
|
||||
);
|
||||
.filter((col) => !col.id.startsWith("__") && col.id !== "qualifier");
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
|
||||
@@ -145,8 +145,6 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
draggable,
|
||||
footer,
|
||||
size = "lg",
|
||||
@@ -223,8 +221,6 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize: effectivePageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
|
||||
@@ -103,10 +103,6 @@ interface UseDataTableOptions<TData extends RowData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. @default {} */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. @default {} */
|
||||
initialRowSelection?: RowSelectionState;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode (filtered to selected rows). @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Called whenever the set of selected row IDs changes. */
|
||||
onSelectionChange?: (selectedIds: string[]) => void;
|
||||
/** Search term for global text filtering. Rows are filtered to those containing
|
||||
@@ -199,8 +195,6 @@ export default function useDataTable<TData extends RowData>(
|
||||
columnResizeMode = "onChange",
|
||||
initialSorting = [],
|
||||
initialColumnVisibility = {},
|
||||
initialRowSelection = {},
|
||||
initialViewSelected = false,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
@@ -212,8 +206,7 @@ export default function useDataTable<TData extends RowData>(
|
||||
|
||||
// ---- internal state -----------------------------------------------------
|
||||
const [sorting, setSorting] = useState<SortingState>(initialSorting);
|
||||
const [rowSelection, setRowSelection] =
|
||||
useState<RowSelectionState>(initialRowSelection);
|
||||
const [rowSelection, setRowSelection] = useState<RowSelectionState>({});
|
||||
const [columnSizing, setColumnSizing] = useState<ColumnSizingState>({});
|
||||
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
|
||||
initialColumnVisibility
|
||||
@@ -223,12 +216,8 @@ export default function useDataTable<TData extends RowData>(
|
||||
pageSize: pageSizeOption,
|
||||
});
|
||||
/** Combined global filter: view-mode (selected IDs) + text search. */
|
||||
const initialSelectedIds =
|
||||
initialViewSelected && Object.keys(initialRowSelection).length > 0
|
||||
? new Set(Object.keys(initialRowSelection))
|
||||
: null;
|
||||
const [globalFilter, setGlobalFilter] = useState<GlobalFilterValue>({
|
||||
selectedIds: initialSelectedIds,
|
||||
selectedIds: null,
|
||||
searchTerm: "",
|
||||
});
|
||||
|
||||
@@ -395,31 +384,6 @@ export default function useDataTable<TData extends RowData>(
|
||||
: data.length;
|
||||
const isPaginated = isFinite(pagination.pageSize);
|
||||
|
||||
// ---- keep view-mode filter in sync with selection ----------------------
|
||||
// When in view-selected mode, deselecting a row should remove it from
|
||||
// the visible set so it disappears immediately.
|
||||
useEffect(() => {
|
||||
if (isServerSide) return;
|
||||
if (globalFilter.selectedIds == null) return;
|
||||
|
||||
const currentIds = new Set(Object.keys(rowSelection));
|
||||
// Remove any ID from the filter that is no longer selected
|
||||
let changed = false;
|
||||
const next = new Set<string>();
|
||||
globalFilter.selectedIds.forEach((id) => {
|
||||
if (currentIds.has(id)) {
|
||||
next.add(id);
|
||||
} else {
|
||||
changed = true;
|
||||
}
|
||||
});
|
||||
if (changed) {
|
||||
setGlobalFilter((prev) => ({ ...prev, selectedIds: next }));
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps -- only react to
|
||||
// selection changes while in view mode
|
||||
}, [rowSelection, isServerSide]);
|
||||
|
||||
// ---- selection change callback ------------------------------------------
|
||||
const isFirstRenderRef = useRef(true);
|
||||
const onSelectionChangeRef = useRef(onSelectionChange);
|
||||
@@ -428,10 +392,6 @@ export default function useDataTable<TData extends RowData>(
|
||||
useEffect(() => {
|
||||
if (isFirstRenderRef.current) {
|
||||
isFirstRenderRef.current = false;
|
||||
// Still fire the callback on first render if there's an initial selection
|
||||
if (selectedRowIds.length > 0) {
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
}
|
||||
return;
|
||||
}
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
|
||||
@@ -146,10 +146,6 @@ export interface DataTableProps<TData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. */
|
||||
initialRowSelection?: Record<string, boolean>;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode. @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Enable drag-and-drop row reordering. */
|
||||
draggable?: DataTableDraggableConfig;
|
||||
/** Footer configuration. */
|
||||
|
||||
@@ -88,12 +88,9 @@ function HoverableRoot({
|
||||
ref,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
onFocusCapture: consumerFocusCapture,
|
||||
onBlurCapture: consumerBlurCapture,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
const [focused, setFocused] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
@@ -111,40 +108,16 @@ function HoverableRoot({
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const onFocusCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
setFocused(true);
|
||||
consumerFocusCapture?.(e);
|
||||
},
|
||||
[consumerFocusCapture]
|
||||
);
|
||||
|
||||
const onBlurCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
if (
|
||||
!(e.relatedTarget instanceof Node) ||
|
||||
!e.currentTarget.contains(e.relatedTarget)
|
||||
) {
|
||||
setFocused(false);
|
||||
}
|
||||
consumerBlurCapture?.(e);
|
||||
},
|
||||
[consumerBlurCapture]
|
||||
);
|
||||
|
||||
const active = hovered || focused;
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<GroupContext.Provider value={active}>
|
||||
<GroupContext.Provider value={hovered}>
|
||||
<div
|
||||
{...props}
|
||||
ref={ref}
|
||||
className={cn(widthVariants[widthVariant])}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseLeave={onMouseLeave}
|
||||
onFocusCapture={onFocusCapture}
|
||||
onBlurCapture={onBlurCapture}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -16,15 +16,3 @@
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Focus — item (or a focusable descendant) receives keyboard focus */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:has(:focus-visible) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Focus ring on keyboard focus */
|
||||
.hoverable-item:focus-visible {
|
||||
outline: 2px solid var(--border-04);
|
||||
outline-offset: 2px;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { Route } from "next";
|
||||
import "@opal/core/interactive/shared.css";
|
||||
import React from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { ButtonType, WithoutStyles } from "@opal/types";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import {
|
||||
containerSizeVariants,
|
||||
type ContainerSizeVariants,
|
||||
@@ -52,7 +52,7 @@ interface InteractiveContainerProps
|
||||
*
|
||||
* Mutually exclusive with `href`.
|
||||
*/
|
||||
type?: ButtonType;
|
||||
type?: "submit" | "button" | "reset";
|
||||
|
||||
/**
|
||||
* When `true`, applies a 1px border using the theme's border color.
|
||||
|
||||
@@ -8,7 +8,7 @@ Stateful interactive surface primitive for elements that maintain a value state
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `variant` | `"select-light" \| "select-heavy" \| "select-tinted" \| "select-filter" \| "sidebar"` | `"select-heavy"` | Color variant |
|
||||
| `variant` | `"select-light" \| "select-heavy" \| "sidebar"` | `"select-heavy"` | Color variant |
|
||||
| `state` | `"empty" \| "filled" \| "selected"` | `"empty"` | Current value state |
|
||||
| `interaction` | `"rest" \| "hover" \| "active"` | `"rest"` | JS-controlled interaction override |
|
||||
| `group` | `string` | — | Tailwind group class for `group-hover:*` |
|
||||
|
||||
@@ -4,7 +4,7 @@ import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
import type { ButtonType, WithoutStyles } from "@opal/types";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -14,7 +14,6 @@ type InteractiveStatefulVariant =
|
||||
| "select-light"
|
||||
| "select-heavy"
|
||||
| "select-tinted"
|
||||
| "select-filter"
|
||||
| "sidebar";
|
||||
type InteractiveStatefulState = "empty" | "filled" | "selected";
|
||||
type InteractiveStatefulInteraction = "rest" | "hover" | "active";
|
||||
@@ -31,8 +30,6 @@ interface InteractiveStatefulProps
|
||||
*
|
||||
* - `"select-light"` — transparent selected background (for inline toggles)
|
||||
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
|
||||
* - `"select-tinted"` — like select-heavy but with a tinted rest background
|
||||
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
|
||||
* - `"sidebar"` — for sidebar navigation items
|
||||
*
|
||||
* @default "select-heavy"
|
||||
@@ -66,13 +63,6 @@ interface InteractiveStatefulProps
|
||||
*/
|
||||
group?: string;
|
||||
|
||||
/**
|
||||
* HTML button type. When set to `"submit"`, `"button"`, or `"reset"`, the
|
||||
* element is treated as inherently interactive for cursor styling purposes
|
||||
* even without an explicit `onClick` or `href`.
|
||||
*/
|
||||
type?: ButtonType;
|
||||
|
||||
/**
|
||||
* URL to navigate to when clicked. Passed through Slot to the child.
|
||||
*/
|
||||
@@ -104,7 +94,6 @@ function InteractiveStateful({
|
||||
state = "empty",
|
||||
interaction = "rest",
|
||||
group,
|
||||
type,
|
||||
href,
|
||||
target,
|
||||
...props
|
||||
@@ -115,7 +104,7 @@ function InteractiveStateful({
|
||||
// so Radix Slot-injected handlers don't bypass this guard.
|
||||
const classes = cn(
|
||||
"interactive",
|
||||
!props.onClick && !href && !type && "!cursor-default !select-auto",
|
||||
!props.onClick && !href && "!cursor-default !select-auto",
|
||||
group
|
||||
);
|
||||
|
||||
|
||||
@@ -308,89 +308,6 @@
|
||||
--interactive-foreground-icon: var(--action-link-03);
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Select-Filter — empty/filled identical to Select-Tinted;
|
||||
selected uses inverted tint backgrounds and inverted text
|
||||
=========================================================================== */
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Filter — Empty & Filled (identical colors)
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
) {
|
||||
@apply bg-background-tint-01;
|
||||
--interactive-foreground: var(--text-02);
|
||||
--interactive-foreground-icon: var(--text-02);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
):hover:not([data-disabled]),
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
)[data-interaction="hover"]:not([data-disabled]) {
|
||||
@apply bg-background-tint-02;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-04);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
):active:not([data-disabled]),
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
)[data-interaction="active"]:not([data-disabled]) {
|
||||
@apply bg-background-neutral-00;
|
||||
--interactive-foreground: var(--text-05);
|
||||
--interactive-foreground-icon: var(--text-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"]:is(
|
||||
[data-interactive-state="empty"],
|
||||
[data-interactive-state="filled"]
|
||||
)[data-disabled] {
|
||||
@apply bg-transparent;
|
||||
--interactive-foreground: var(--text-01);
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Filter — Selected
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"] {
|
||||
@apply bg-background-tint-inverted-03;
|
||||
--interactive-foreground: var(--text-inverted-05);
|
||||
--interactive-foreground-icon: var(--text-inverted-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-interaction="hover"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-inverted-04;
|
||||
--interactive-foreground: var(--text-inverted-05);
|
||||
--interactive-foreground-icon: var(--text-inverted-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"]:active:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-interaction="active"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-tint-inverted-04;
|
||||
--interactive-foreground: var(--text-inverted-04);
|
||||
--interactive-foreground-icon: var(--text-inverted-04);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-disabled] {
|
||||
@apply bg-background-neutral-04;
|
||||
--interactive-foreground: var(--text-inverted-04);
|
||||
--interactive-foreground-icon: var(--text-inverted-02);
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Sidebar
|
||||
=========================================================================== */
|
||||
|
||||
@@ -4,7 +4,7 @@ import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
import type { ButtonType, WithoutStyles } from "@opal/types";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -53,13 +53,6 @@ interface InteractiveStatelessProps
|
||||
*/
|
||||
group?: string;
|
||||
|
||||
/**
|
||||
* HTML button type. When set to `"submit"`, `"button"`, or `"reset"`, the
|
||||
* element is treated as inherently interactive for cursor styling purposes
|
||||
* even without an explicit `onClick` or `href`.
|
||||
*/
|
||||
type?: ButtonType;
|
||||
|
||||
/**
|
||||
* URL to navigate to when clicked. Passed through Slot to the child.
|
||||
*/
|
||||
@@ -92,7 +85,6 @@ function InteractiveStateless({
|
||||
prominence = "primary",
|
||||
interaction = "rest",
|
||||
group,
|
||||
type,
|
||||
href,
|
||||
target,
|
||||
...props
|
||||
@@ -103,7 +95,7 @@ function InteractiveStateless({
|
||||
// so Radix Slot-injected handlers don't bypass this guard.
|
||||
const classes = cn(
|
||||
"interactive",
|
||||
!props.onClick && !href && !type && "!cursor-default !select-auto",
|
||||
!props.onClick && !href && "!cursor-default !select-auto",
|
||||
group
|
||||
);
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgEyeOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M11.78 11.78C10.6922 12.6092 9.36761 13.0685 8 13.0909C3.54545 13.0909 1 8 1 8C1.79157 6.52484 2.88945 5.23602 4.22 4.22M11.78 11.78L9.34909 9.34909M11.78 11.78L15 15M4.22 4.22L1 1M4.22 4.22L6.65091 6.65091M6.66364 3.06182C7.10167 2.95929 7.55013 2.90803 8 2.90909C12.4545 2.90909 15 8 15 8C14.6137 8.72266 14.153 9.40301 13.6255 10.03M9.34909 9.34909L6.65091 6.65091M9.34909 9.34909C8.99954 9.72422 8.49873 9.94737 7.98606 9.95641C6.922 9.97519 6.02481 9.078 6.04358 8.01394C6.05263 7.50127 6.27578 7.00046 6.65091 6.65091"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgEyeOff;
|
||||
@@ -1,21 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgFileBroadcast = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 18 18"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M6.1875 2.25003H2.625C1.808 2.25003 1.125 2.93303 1.125 3.75003L1.125 14.25C1.125 15.067 1.808 15.75 2.625 15.75L9.37125 15.75C10.1883 15.75 10.8713 15.067 10.8713 14.25L10.8713 6.94128M6.1875 2.25003L10.8713 6.94128M6.1875 2.25003V6.94128H10.8713M10.3069 2.25L13.216 5.15914C13.6379 5.5811 13.875 6.15339 13.875 6.75013V13.875C13.875 14.5212 13.737 15.2081 13.4392 15.7538M16.4391 15.7538C16.737 15.2081 16.875 14.5213 16.875 13.8751L16.875 7.02481C16.875 5.53418 16.2833 4.10451 15.23 3.04982L14.4301 2.25003"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgFileBroadcast;
|
||||
@@ -1,21 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgHookNodes = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M10.0002 4C10.0002 3.99708 10.0002 3.99415 10.0001 3.99123C9.99542 2.8907 9.10181 2 8.00016 2C6.89559 2 6.00016 2.89543 6.00016 4C6.00016 4.73701 6.39882 5.38092 6.99226 5.72784L4.67276 9.70412M11.6589 13.7278C11.9549 13.9009 12.2993 14 12.6668 14C13.7714 14 14.6668 13.1046 14.6668 12C14.6668 10.8954 13.7714 10 12.6668 10C12.2993 10 11.9549 10.0991 11.6589 10.2722L9.33943 6.29588M2.33316 10.2678C1.73555 10.6136 1.3335 11.2599 1.3335 12C1.3335 13.1046 2.22893 14 3.3335 14C4.43807 14 5.3335 13.1046 5.3335 12H10.0002"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgHookNodes;
|
||||
@@ -68,10 +68,8 @@ export { default as SvgExpand } from "@opal/icons/expand";
|
||||
export { default as SvgExternalLink } from "@opal/icons/external-link";
|
||||
export { default as SvgEye } from "@opal/icons/eye";
|
||||
export { default as SvgEyeClosed } from "@opal/icons/eye-closed";
|
||||
export { default as SvgEyeOff } from "@opal/icons/eye-off";
|
||||
export { default as SvgFileBraces } from "@opal/icons/file-braces";
|
||||
export { default as SvgFileBroadcast } from "@opal/icons/file-broadcast";
|
||||
export { default as SvgFiles } from "@opal/icons/files";
|
||||
export { default as SvgFileBraces } from "@opal/icons/file-braces";
|
||||
export { default as SvgFileChartPie } from "@opal/icons/file-chart-pie";
|
||||
export { default as SvgFileSmall } from "@opal/icons/file-small";
|
||||
export { default as SvgFileText } from "@opal/icons/file-text";
|
||||
@@ -91,7 +89,6 @@ export { default as SvgHashSmall } from "@opal/icons/hash-small";
|
||||
export { default as SvgHash } from "@opal/icons/hash";
|
||||
export { default as SvgHeadsetMic } from "@opal/icons/headset-mic";
|
||||
export { default as SvgHistory } from "@opal/icons/history";
|
||||
export { default as SvgHookNodes } from "@opal/icons/hook-nodes";
|
||||
export { default as SvgHourglass } from "@opal/icons/hourglass";
|
||||
export { default as SvgImage } from "@opal/icons/image";
|
||||
export { default as SvgImageSmall } from "@opal/icons/image-small";
|
||||
@@ -123,9 +120,7 @@ export { default as SvgNetworkGraph } from "@opal/icons/network-graph";
|
||||
export { default as SvgNotificationBubble } from "@opal/icons/notification-bubble";
|
||||
export { default as SvgOllama } from "@opal/icons/ollama";
|
||||
export { default as SvgOnyxLogo } from "@opal/icons/onyx-logo";
|
||||
export { default as SvgOnyxLogoTyped } from "@opal/icons/onyx-logo-typed";
|
||||
export { default as SvgOnyxOctagon } from "@opal/icons/onyx-octagon";
|
||||
export { default as SvgOnyxTyped } from "@opal/icons/onyx-typed";
|
||||
export { default as SvgOpenai } from "@opal/icons/openai";
|
||||
export { default as SvgOpenrouter } from "@opal/icons/openrouter";
|
||||
export { default as SvgOrganization } from "@opal/icons/organization";
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import SvgOnyxLogo from "@opal/icons/onyx-logo";
|
||||
import SvgOnyxTyped from "@opal/icons/onyx-typed";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
interface OnyxLogoTypedProps {
|
||||
size?: number;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
// # NOTE(@raunakab):
|
||||
// This ratio is not some random, magical number; it is available on Figma.
|
||||
const HEIGHT_TO_GAP_RATIO = 5 / 16;
|
||||
|
||||
const SvgOnyxLogoTyped = ({ size: height, className }: OnyxLogoTypedProps) => {
|
||||
const gap = height != null ? height * HEIGHT_TO_GAP_RATIO : undefined;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(`flex flex-row items-center`, className)}
|
||||
style={{ gap }}
|
||||
>
|
||||
<SvgOnyxLogo size={height} />
|
||||
<SvgOnyxTyped size={height} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
export default SvgOnyxLogoTyped;
|
||||
@@ -1,27 +1,19 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 64 64"
|
||||
fill="none"
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M10.4014 13.25L18.875 32L10.3852 50.75L2 32L10.4014 13.25Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M53.5264 13.25L62 32L53.5102 50.75L45.125 32L53.5264 13.25Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M32 45.125L50.75 53.5625L32 62L13.25 53.5625L32 45.125Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M32 2L50.75 10.4375L32 18.875L13.25 10.4375L32 2Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M28 0 10.869 7.77 28 15.539l17.131-7.77L28 0Zm0 40.461-17.131 7.77L28 56l17.131-7.77L28 40.461Zm20.231-29.592L56 28.001l-7.769 17.131L40.462 28l7.769-17.131ZM15.538 28 7.77 10.869 0 28l7.769 17.131L15.538 28Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgOnyxTyped = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
height={size}
|
||||
viewBox="0 0 152 64"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M19.1795 51.2136C15.6695 51.2136 12.4353 50.3862 9.47691 48.7315C6.56865 47.0768 4.2621 44.8454 2.55726 42.0374C0.85242 39.1793 0 36.0955 0 32.7861C0 30.279 0.451281 27.9223 1.35384 25.716C2.30655 23.4596 3.76068 21.3285 5.71623 19.3228L11.8085 13.08C12.4604 12.6789 13.4131 12.3529 14.6666 12.1022C15.9202 11.8014 17.2991 11.6509 18.8034 11.6509C22.3134 11.6509 25.5225 12.4783 28.4307 14.133C31.3891 15.7877 33.7208 18.0441 35.4256 20.9023C37.1304 23.7103 37.9829 26.794 37.9829 30.1536C37.9829 32.6106 37.5065 34.9673 36.5538 37.2237C35.6512 39.4802 34.147 41.6864 32.041 43.8426L26.3248 49.7845C25.3219 50.2358 24.2188 50.5868 23.0154 50.8375C21.8621 51.0882 20.5835 51.2136 19.1795 51.2136ZM20.1572 43.8426C21.8621 43.8426 23.4917 43.4164 25.0461 42.5639C26.6005 41.6614 27.8541 40.3577 28.8068 38.6528C29.8097 36.948 30.3111 34.9172 30.3111 32.5605C30.3111 30.0032 29.6843 27.6966 28.4307 25.6408C27.2273 23.5849 25.6478 21.9803 23.6923 20.8271C21.7869 19.6236 19.8313 19.0219 17.8256 19.0219C16.0706 19.0219 14.4159 19.4732 12.8615 20.3758C11.3573 21.2282 10.1288 22.5068 9.17606 24.2117C8.22335 25.9166 7.747 27.9473 7.747 30.304C7.747 32.8613 8.34871 35.1679 9.55212 37.2237C10.7555 39.2796 12.31 40.9092 14.2154 42.1127C16.1709 43.2659 18.1515 43.8426 20.1572 43.8426Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M42.6413 50.4614V12.4031H50.6891V17.7433L55.5028 12.7039C56.0544 12.4532 56.8065 12.2276 57.7592 12.027C58.7621 11.7763 59.8903 11.6509 61.1438 11.6509C64.0521 11.6509 66.5843 12.3028 68.7404 13.6065C70.9467 14.8601 72.6264 16.6401 73.7797 18.9467C74.9831 21.2533 75.5848 23.961 75.5848 27.0698V50.4614H67.6122V29.1006C67.6122 26.9946 67.2612 25.1895 66.5592 23.6852C65.9074 22.1308 64.9547 20.9775 63.7011 20.2253C62.4977 19.4231 61.0686 19.0219 59.4139 19.0219C56.7564 19.0219 54.6253 19.9245 53.0208 21.7296C51.4663 23.4846 50.6891 25.9416 50.6891 29.1006V50.4614H42.6413Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M82.3035 64V56.0273H89.9753C91.2288 56.0273 92.2066 55.7264 92.9086 55.1247C93.6607 54.523 94.2625 53.5452 94.7137 52.1913L108.027 12.4031H116.751L103.664 49.4084C103.062 51.1634 102.461 52.5173 101.859 53.47C101.307 54.4227 100.53 55.4506 99.5274 56.5538L92.4573 64H82.3035ZM90.7274 46.6255L76.9633 12.4031H85.989L99.4522 46.6255H90.7274Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
<path
|
||||
d="M115.657 50.4614L129.045 31.2066L116.033 12.4031H125.435L134.085 24.8134L142.358 12.4031H151.308L138.372 31.0562L151.684 50.4614H142.358L133.332 37.3742L124.683 50.4614H115.657Z"
|
||||
fill="var(--theme-primary-05)"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgOnyxTyped;
|
||||
@@ -32,8 +32,6 @@ interface ContentMdPresetConfig {
|
||||
optionalFont: string;
|
||||
/** Aux icon size = lineHeight − 2 × p-0.5. */
|
||||
auxIconSize: string;
|
||||
/** Left indent for the description so it aligns with the title (past the icon). */
|
||||
descriptionIndent: string;
|
||||
}
|
||||
|
||||
interface ContentMdProps {
|
||||
@@ -87,7 +85,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-content-muted",
|
||||
auxIconSize: "1.25rem",
|
||||
descriptionIndent: "1.625rem",
|
||||
},
|
||||
"main-ui": {
|
||||
iconSize: "1rem",
|
||||
@@ -100,7 +97,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-ui-muted",
|
||||
auxIconSize: "1rem",
|
||||
descriptionIndent: "1.375rem",
|
||||
},
|
||||
secondary: {
|
||||
iconSize: "0.75rem",
|
||||
@@ -113,7 +109,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-secondary-action",
|
||||
auxIconSize: "0.75rem",
|
||||
descriptionIndent: "1.125rem",
|
||||
},
|
||||
};
|
||||
|
||||
@@ -168,25 +163,22 @@ function ContentMd({
|
||||
data-interactive={withInteractive || undefined}
|
||||
style={{ gap: config.gap }}
|
||||
>
|
||||
<div
|
||||
className="opal-content-md-header"
|
||||
data-editing={editing || undefined}
|
||||
>
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-md-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className={cn("opal-content-md-icon", config.iconColorClass)}
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{Icon && (
|
||||
<div
|
||||
className={cn(
|
||||
"opal-content-md-icon-container shrink-0",
|
||||
config.iconContainerPadding
|
||||
)}
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className={cn("opal-content-md-icon", config.iconColorClass)}
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="opal-content-md-body">
|
||||
<div className="opal-content-md-title-row">
|
||||
{editing ? (
|
||||
<div className="opal-content-md-input-sizer">
|
||||
@@ -282,16 +274,13 @@ function ContentMd({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
<div
|
||||
className="opal-content-md-description font-secondary-body text-text-03"
|
||||
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
|
||||
>
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
{description && (
|
||||
<div className="opal-content-md-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -224,16 +224,7 @@
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md {
|
||||
@apply flex flex-col items-start;
|
||||
}
|
||||
|
||||
.opal-content-md-header {
|
||||
@apply flex flex-row items-center w-full;
|
||||
}
|
||||
|
||||
.opal-content-md-header[data-editing] {
|
||||
@apply rounded-08;
|
||||
box-shadow: inset 0 0 0 1px var(--border-02);
|
||||
@apply flex flex-row items-start;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
@@ -246,6 +237,15 @@
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Body column
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
.opal-content-md-body {
|
||||
@apply flex flex-1 flex-col items-start;
|
||||
min-width: 0.0625rem;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Title row — title (or input) + edit button
|
||||
--------------------------------------------------------------------------- */
|
||||
@@ -267,7 +267,6 @@
|
||||
.opal-content-md-input-sizer {
|
||||
display: inline-grid;
|
||||
align-items: stretch;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.opal-content-md-input-sizer > * {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user