Compare commits

..

42 Commits

Author SHA1 Message Date
Nik
cf403d9c89 feat(chat): carousel peek animation and interface cleanup for multi-model UI
- Extract MultiModelResponse interface to message/interfaces.ts (co-locate
  per CLAUDE.md convention; update multi-model-preview/page.tsx import)
- Carousel/fade in selection mode: non-preferred panels animate from adjacent
  to preferred → peek position (64px) at container edges, with mask-image
  gradient fade so they dissolve naturally at the viewport boundary
- Entry animation: panels slide out via CSS cubic-bezier transition triggered
  by requestAnimationFrame after mount, preventing initial-position flicker
- Fix test import: useMultiModelChat.test.tsx now imports renderHook+act from
  @tests/setup/test-utils (project wrapper) instead of @testing-library/react
- Revert FrostedDiv.tsx: remove unnecessary wrapperClassName prop (unused by
  any multi-model component)
- Add multi-model-preview dev page (was untracked)
2026-03-25 23:47:17 -07:00
Nik
ce7e68f671 fix(chat): fix hover state, dedup MAX_MODELS, clean up ModelSelector
- Replace Hoverable.Root/Item overlay with hover:bg-background-tint-02
  directly on panel container — correct Figma hover state (full-panel
  tint, no badge)
- Export MAX_MODELS from ModelSelector and import it in useMultiModelChat
  to eliminate the duplicate constant
- Replace @/components/ui/accordion import with @radix-ui/react-accordion
  (removes legacy directory dep)
- Remove dead "Compare Model" button with no onClick handler
- Remove SvgColumn import that was only used by the dead button
- Remove wrapper div around hidden MultiModelPanel (panel already
  self-sizes to w-[220px])
- Tighten panel max-width from 720px to 640px to match chat column width
2026-03-25 23:19:12 -07:00
Nik
65c5d5d5d9 feat(chat): add multi-model UI components and hook 2026-03-25 21:33:03 -07:00
Nik
ebe558e04f feat(chat): add frontend types and API helpers for multi-model streaming 2026-03-25 20:41:13 -07:00
Nik
a49edf3e18 feat(chat): add multi-model parallel streaming backend 2026-03-25 20:17:38 -07:00
Nikolas Garza
eb1244a9d7 feat(chat): add DB schema and Pydantic models for multi-model answers (#9646) 2026-03-26 02:21:00 +00:00
Evan Lohn
2433a9a4c5 feat: sharepoint filters (denylist) (#9649) 2026-03-26 01:33:18 +00:00
dependabot[bot]
60bc8fcac6 chore(deps): bump nltk from 3.9.3 to 3.9.4 (#9663)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-26 00:50:52 +00:00
dependabot[bot]
1ddc958a51 chore(deps): bump picomatch in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9662)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-25 17:54:28 -07:00
acaprau
de37acbe07 chore(opensearch): Optimize terms filters; add type aliases (#9619) 2026-03-26 00:35:53 +00:00
Wenxi
08cd2f2c3e fix(ci): tag web-server and model-server with craft-latest (#9661)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 00:35:37 +00:00
acaprau
fc29f20914 feat(opensearch): Add Prometheus metrics for OpenSearch retrieval (#9654) 2026-03-26 00:29:29 +00:00
dependabot[bot]
c43cb80a7a chore(deps): bump yaml from 1.10.2 to 1.10.3 in /web (#9655)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-25 23:59:17 +00:00
dependabot[bot]
56f0be2ec8 chore(deps): bump requests from 2.32.5 to 2.33.0 (#9652)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-25 23:59:00 +00:00
acaprau
42f9ddf247 feat(opensearch): Search UI search flow can be configured to use pure keyword search (#9500)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-25 23:56:32 +00:00
dependabot[bot]
a10a85c73c chore(deps-dev): bump picomatch from 4.0.3 to 4.0.4 in /widget (#9659)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-25 17:01:38 -07:00
Jamison Lahman
31d8ae9718 chore(playwright): rework admin navigation tests (#9650)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-25 23:27:08 +00:00
Nikolas Garza
00a0a99842 fix: clarify service account API key upgrade message for trial accounts (#9581) 2026-03-25 23:22:45 +00:00
dependabot[bot]
90040f8973 chore(deps): bump picomatch in /examples/widget (#9651)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-25 16:20:36 -07:00
Raunak Bhagat
4f5d081f26 feat(opal): add Text component with inline markdown support (#9623) 2026-03-25 23:06:18 +00:00
dependabot[bot]
c51a6dbd0d chore(deps): bump pypdf from 6.9.1 to 6.9.2 (#9637)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-25 23:04:27 +00:00
Evan Lohn
8b90ecc189 feat: sharepoint shareable links non-public (#9636) 2026-03-25 22:50:29 +00:00
Justin Tahara
865c893a09 chore(agents): Match Mocks & Add Date Validation (#9632) 2026-03-25 21:57:31 +00:00
Bo-Onyx
ef5628bfa7 feat(hook): Frontend hook infrastructure (#9634) 2026-03-25 21:38:04 +00:00
Jessica Singh
6ffee0021e chore(voice): align fe with other admin pages (#9505) 2026-03-25 20:00:36 +00:00
Jessica Singh
28dc84b831 fix(notion): upgrade API version + logical changes (#9609)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 19:18:51 +00:00
Jamison Lahman
230f035500 fix(chat): dont clear input message after errors submitting (#9624) 2026-03-25 12:00:23 -07:00
Jamison Lahman
55b24d72b4 fix(fe): redirect to status page after deleting connector (#9620) 2026-03-25 17:24:41 +00:00
Raunak Bhagat
3321a84c7d fix(sidebar): fix icon alignment for user-avatar-popover (#9615)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-25 17:07:50 +00:00
SubashMohan
54bf32a5f8 fix: use persisted source functions when toggling search tool (#9548) 2026-03-25 16:50:25 +00:00
Nikolas Garza
4bb6b76be6 feat(groups): switchover to /admin/groups and rewrite e2e tests (#9545) 2026-03-25 08:11:13 +00:00
SubashMohan
db94562474 feat: Group-based permissions — Phase 1 schema (AccountType, Permission, PermissionGrant) (#9547) 2026-03-25 06:24:43 +00:00
Nikolas Garza
582d4642c1 feat(metrics): add task lifecycle and per-connector Prometheus metrics (#9602) 2026-03-25 06:02:43 +00:00
Nikolas Garza
3caaecdb0e feat(groups): polish edit page table and delete UX (#9544) 2026-03-25 04:57:50 +00:00
Nikolas Garza
039b69806b feat(metrics): add queue depth and connector health Prometheus collectors (#9590) 2026-03-25 03:53:26 +00:00
Evan Lohn
63971d4958 fix: confluence client retries (#9605) 2026-03-25 03:32:29 +00:00
Nikolas Garza
ffd897f380 feat(metrics): add reusable Prometheus metrics server for celery workers (#9589) 2026-03-25 01:47:06 +00:00
Evan Lohn
4745069232 fix: no more lazy queries per search call (#9578) 2026-03-25 01:38:35 +00:00
Nikolas Garza
386782f188 feat(groups): add edit group page (#9543) 2026-03-25 01:22:57 +00:00
Raunak Bhagat
ff009c4129 fix: Fix tag widths (#9618) 2026-03-25 01:18:51 +00:00
Bo-Onyx
b20a5ebf69 feat(hook): Add frontend feature control and admin hook page (#9575) 2026-03-25 00:37:37 +00:00
Bo-Onyx
8645adb807 fix(width): UI update model width definition. (#9613) 2026-03-25 00:11:32 +00:00
204 changed files with 10992 additions and 2825 deletions

View File

@@ -615,6 +615,7 @@ jobs:
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
@@ -1263,8 +1264,6 @@ jobs:
latest=false
tags: |
type=raw,value=craft-latest
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
# to keep tagging strategy consistent across all backend images
- name: Create and push manifest
env:
@@ -1488,6 +1487,7 @@ jobs:
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

12
.vscode/launch.json vendored
View File

@@ -117,7 +117,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console"
"consoleTitle": "API Server Console",
"justMyCode": false
},
{
"name": "Slack Bot",
@@ -268,7 +269,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
"consoleTitle": "Celery heavy Console",
"justMyCode": false
},
{
"name": "Celery kg_processing",
@@ -355,7 +357,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
"consoleTitle": "Celery user_file_processing Console",
"justMyCode": false
},
{
"name": "Celery docfetching",
@@ -413,7 +416,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console"
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery beat",

View File

@@ -0,0 +1,109 @@
"""group_permissions_phase1
Revision ID: 25a5501dc766
Revises: b728689f45b1
Create Date: 2026-03-23 11:41:25.557442
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
# revision identifiers, used by Alembic.
revision = "25a5501dc766"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1. Add account_type column to user table (nullable for now).
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
op.add_column(
"user",
sa.Column(
"account_type",
sa.Enum(AccountType, native_enum=False),
nullable=True,
),
)
# 2. Add is_default column to user_group table
op.add_column(
"user_group",
sa.Column(
"is_default",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# 3. Create permission_grant table
op.create_table(
"permission_grant",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("group_id", sa.Integer(), nullable=False),
sa.Column(
"permission",
sa.Enum(Permission, native_enum=False),
nullable=False,
),
sa.Column(
"grant_source",
sa.Enum(GrantSource, native_enum=False),
nullable=False,
),
sa.Column(
"granted_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"granted_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"is_deleted",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["group_id"],
["user_group.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["granted_by"],
["user.id"],
ondelete="SET NULL",
),
sa.UniqueConstraint(
"group_id", "permission", name="uq_permission_grant_group_permission"
),
)
# 4. Index on user__user_group(user_id) — existing composite PK
# has user_group_id as leading column; user-filtered queries need this
op.create_index(
"ix_user__user_group_user_id",
"user__user_group",
["user_id"],
)
def downgrade() -> None:
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
op.drop_table("permission_grant")
op.drop_column("user_group", "is_default")
op.drop_column("user", "account_type")

View File

@@ -0,0 +1,36 @@
"""add preferred_response_id and model_display_name to chat_message
Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "25a5501dc766"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"preferred_response_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="SET NULL"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("model_display_name", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "model_display_name")
op.drop_column("chat_message", "preferred_response_id")

View File

@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:

View File

@@ -250,20 +250,24 @@ def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
raise e
def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
def _is_public_item(
drive_item: DriveItem,
treat_sharing_link_as_public: bool = False,
) -> bool:
if not treat_sharing_link_as_public:
return False
try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
if permission.link and permission.link.scope in (
"anonymous",
"organization",
):
is_public = True
break
return is_public
return True
return False
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
@@ -504,6 +508,7 @@ def get_external_access_from_sharepoint(
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
treat_sharing_link_as_public: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
@@ -563,8 +568,7 @@ def get_external_access_from_sharepoint(
)
if drive_item and drive_name:
# Here we check if the item has any public links; if so, we return early
is_public = _is_public_item(drive_item)
is_public = _is_public_item(drive_item, treat_sharing_link_as_public)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(

View File

@@ -44,19 +44,21 @@ def _run_single_search(
user: User,
db_session: Session,
num_hits: int | None = None,
hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
"""Execute a single search query and return chunks."""
chunk_search_request = ChunkSearchRequest(
query=query,
user_selected_filters=filters,
limit=num_hits,
hybrid_alpha=hybrid_alpha,
)
return search_pipeline(
chunk_search_request=chunk_search_request,
document_index=document_index,
user=user,
persona=None, # No persona for direct search
persona_search_info=None,
db_session=db_session,
)
@@ -74,7 +76,7 @@ def stream_search_query(
Core search function that yields streaming packets.
Used by both streaming and non-streaming endpoints.
"""
# Get document index
# Get document index.
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
@@ -119,6 +121,7 @@ def stream_search_query(
user=user,
db_session=db_session,
num_hits=request.num_hits,
hybrid_alpha=request.hybrid_alpha,
)
else:
# Multiple queries - run in parallel and merge with RRF
@@ -133,6 +136,7 @@ def stream_search_query(
user,
db_session,
request.num_hits,
request.hybrid_alpha,
),
)
for query in all_executed_queries

View File

@@ -27,15 +27,17 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
# NOTE: This model is used for the core flow of the Onyx application, any
# changes to it should be reviewed and approved by an experienced team member.
# It is very important to 1. avoid bloat and 2. that this remains backwards
# compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30
hybrid_alpha: float | None = None
include_content: bool = False
stream: bool = False

View File

@@ -20,6 +20,7 @@ from ee.onyx.server.query_and_chat.models import SearchQueryResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from onyx.auth.users import current_user
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
@@ -67,8 +68,10 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
# NOTE: This endpoint is used for the core flow of the Onyx application, any
# changes to it should be reviewed and approved by an experienced team member.
# It is very important to 1. avoid bloat and 2. that this remains backwards
# compatible across versions.
@router.post(
"/send-search-message",
response_model=None,
@@ -80,13 +83,19 @@ def handle_send_search_message(
db_session: Session = Depends(get_session),
) -> StreamingResponse | SearchFullResponse:
"""
Execute a search query with optional streaming.
Executes a search query with optional streaming.
When stream=True: Returns StreamingResponse with SSE
When stream=False: Returns SearchFullResponse
If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
is True, executes pure keyword search.
Returns:
StreamingResponse with SSE if stream=True, otherwise SearchFullResponse.
"""
logger.debug(f"Received search query: {request.search_query}")
if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH:
request.hybrid_alpha = 0.0
# Non-streaming path
if not request.stream:
try:

View File

@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -34,6 +42,8 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -48,6 +58,36 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docfetching")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -35,6 +43,8 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -49,6 +59,36 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docprocessing")
app_base.on_worker_ready(sender, **kwargs)
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
# effectively a no-op in normal operation. It remains as a safety net in case
# the pool type is ever changed to prefork. Prometheus metrics are safe in
# thread-pool mode since all threads share the same process memory and can
# update the same Counter/Gauge/Histogram objects directly.
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()

View File

@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
app_base.on_celeryd_init(sender, conf, **kwargs)
# Set by on_worker_init so on_worker_ready knows whether to start the server.
_prometheus_collectors_ok: bool = False
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
global _prometheus_collectors_ok
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.on_secondary_worker_init(sender, **kwargs)
def _setup_prometheus_collectors(sender: Any) -> bool:
"""Register Prometheus collectors that need Redis/DB access.
Passes the Celery app so the queue depth collector can obtain a fresh
broker Redis client on each scrape (rather than holding a stale reference).
Returns True if registration succeeded, False otherwise.
"""
try:
from onyx.server.metrics.indexing_pipeline_setup import (
setup_indexing_pipeline_metrics,
)
setup_indexing_pipeline_metrics(sender.app)
logger.info("Prometheus indexing pipeline collectors registered")
return True
except Exception:
logger.exception("Failed to register Prometheus indexing pipeline collectors")
return False
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
if _prometheus_collectors_ok:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring")
else:
logger.warning(
"Skipping Prometheus metrics server — collector registration failed"
)
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStreamPart = (
Packet
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamingError
| CreateChatSessionID
)
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
"""
import io
import queue
import re
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import Token
from uuid import UUID
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
@@ -59,6 +62,8 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -86,16 +91,21 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import ModelResponseSlot
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.tools.constants import SEARCH_TOOL_ID
@@ -1069,6 +1079,583 @@ def handle_stream_message_objects(
logger.exception("Error in setting processing status")
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
# Sentinel placed on the merged queue when a model thread finishes.
_MODEL_DONE = object()
class _ModelIndexEmitter(Emitter):
    """Emitter that stamps each packet with its model_index and pushes it
    straight onto a queue shared by all model threads.

    The stock Emitter accumulates packets in its own local bus; this variant
    instead forwards every packet to ``merged_queue`` the moment it is
    emitted. That lets tokens from several concurrently-running models
    interleave on the wire rather than arriving in per-model bursts after
    each model completes.
    """

    def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
        # The base class still receives a bus for compatibility; it is never read.
        super().__init__(queue.Queue())
        self._model_idx = model_idx
        self._merged_queue = merged_queue

    def emit(self, packet: Packet) -> None:
        # Re-stamp the placement with this emitter's model index, defaulting
        # missing placement fields the same way the single-model path would.
        source = packet.placement
        stamped = Placement(
            turn_index=source.turn_index if source else 0,
            tab_index=source.tab_index if source else 0,
            sub_turn_index=source.sub_turn_index if source else None,
            model_index=self._model_idx,
        )
        self._merged_queue.put(
            (self._model_idx, Packet(placement=stamped, obj=packet.obj))
        )
def run_multi_model_stream(
    new_msg_req: SendMessageRequest,
    user: User,
    db_session: Session,
    llm_overrides: list[LLMOverride],
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
    # TODO(ENG-3888): The setup logic below (session resolution through tool construction)
    # is duplicated from handle_stream_message_objects. Extract into a shared
    # _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
    # both paths call the same setup code.
    # https://linear.app/onyx-app/issue/ENG-3888
    """Run 2-3 LLMs in parallel and yield their packets tagged with model_index.

    Args:
        new_msg_req: Incoming message request (session info/id, message text,
            file descriptors, tool selection, search filters, etc.).
        user: Requesting user; drives LLM identity, memories, and file access.
        db_session: Main-thread session, used only for setup (before worker
            threads launch) and completion callbacks (after they exit).
        llm_overrides: Exactly 2 or 3 overrides, one per model to run.
        litellm_additional_headers: Optional extra headers for LLM calls.
        custom_tool_additional_headers: Optional extra headers for custom tools.
        mcp_headers: Optional headers forwarded to MCP tools.

    Yields:
        Setup events (CreateChatSessionID for new sessions, then
        MultiModelMessageResponseIDInfo with reserved assistant message IDs),
        then model packets tagged with model_index, StreamingError objects on
        per-model or fatal failures, and a final OverallStop packet.

    Raises:
        ValueError: Fewer than 2 / more than 3 overrides, or deep research
            requested (both raised before the first yield).

    Resource management:
    - Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
    - The caller's db_session is used only for setup (before threads launch) and
      completion callbacks (after threads finish)
    - ThreadPoolExecutor is bounded to len(overrides) workers
    - All threads are joined in the finally block regardless of success/failure
    - Queue-based merging avoids busy-waiting
    """
    n_models = len(llm_overrides)
    if n_models < 2 or n_models > 3:
        raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
    if new_msg_req.deep_research:
        raise ValueError("Multi-model is not supported with deep research")
    tenant_id = get_current_tenant_id()
    cache: CacheBackend | None = None
    chat_session: ChatSession | None = None
    user_id = user.id
    # Identifier sent to the LLM provider for per-user attribution.
    if user.is_anonymous:
        llm_user_identifier = "anonymous_user"
    else:
        llm_user_identifier = user.email or str(user_id)
    try:
        # ── Session setup (same as single-model path) ──────────────────
        if not new_msg_req.chat_session_id:
            if not new_msg_req.chat_session_info:
                raise RuntimeError(
                    "Must specify a chat session id or chat session info"
                )
            chat_session = create_chat_session_from_request(
                chat_session_request=new_msg_req.chat_session_info,
                user_id=user_id,
                db_session=db_session,
            )
            yield CreateChatSessionID(chat_session_id=chat_session.id)
        else:
            chat_session = get_chat_session_by_id(
                chat_session_id=new_msg_req.chat_session_id,
                user_id=user_id,
                db_session=db_session,
            )
        persona = chat_session.persona
        message_text = new_msg_req.message
        # ── Build N LLM instances and validate costs ───────────────────
        llms: list[LLM] = []
        model_display_names: list[str] = []
        for override in llm_overrides:
            llm = get_llm_for_persona(
                persona=persona,
                user=user,
                llm_override=override,
                additional_headers=litellm_additional_headers,
            )
            # Enforce per-provider spend limits before launching any model.
            check_llm_cost_limit_for_provider(
                db_session=db_session,
                tenant_id=tenant_id,
                llm_provider_api_key=llm.config.api_key,
            )
            llms.append(llm)
            model_display_names.append(_build_model_display_name(override))
        # Use first LLM for token counting (context window is checked per-model
        # but token counting is model-agnostic enough for setup purposes)
        token_counter = get_llm_token_counter(llms[0])
        verify_user_files(
            user_files=new_msg_req.file_descriptors,
            user_id=user_id,
            db_session=db_session,
            project_id=chat_session.project_id,
        )
        # ── Chat history chain (shared across all models) ──────────────
        chat_history = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )
        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )
        # Resolve the parent message: latest mainline message, the root, or
        # an explicit ancestor (truncating history to that point).
        if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
            parent_message = chat_history[-1] if chat_history else root_message
        elif (
            new_msg_req.parent_message_id is None
            or new_msg_req.parent_message_id == root_message.id
        ):
            parent_message = root_message
            chat_history = []
        else:
            parent_message = None
            # Walk backwards so we keep the longest prefix ending at the parent.
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i].id == new_msg_req.parent_message_id:
                    parent_message = chat_history[i]
                    chat_history = chat_history[: i + 1]
                    break
            if parent_message is None:
                raise ValueError(
                    "The new message sent is not on the latest mainline of messages"
                )
        # Reuse the parent if it is already a user message (e.g. regenerate);
        # otherwise persist a new user message at the end of the branch.
        if parent_message.message_type == MessageType.USER:
            user_message = parent_message
        else:
            user_message = create_new_chat_message(
                chat_session_id=chat_session.id,
                parent_message=parent_message,
                message=message_text,
                token_count=token_counter(message_text),
                message_type=MessageType.USER,
                files=new_msg_req.file_descriptors,
                db_session=db_session,
                commit=True,
            )
            chat_history.append(user_message)
        available_files = _collect_available_file_ids(
            chat_history=chat_history,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )
        # If this branch was summarized, drop already-summarized messages but
        # remember their file metadata so those files remain referenceable.
        summary_message = find_summary_for_branch(db_session, chat_history)
        summarized_file_metadata: dict[str, FileToolMetadata] = {}
        if summary_message and summary_message.last_summarized_message_id:
            cutoff_id = summary_message.last_summarized_message_id
            for msg in chat_history:
                if msg.id > cutoff_id or not msg.files:
                    continue
                for fd in msg.files:
                    file_id = fd.get("id")
                    if not file_id:
                        continue
                    summarized_file_metadata[file_id] = FileToolMetadata(
                        file_id=file_id,
                        filename=fd.get("name") or "unknown",
                        approx_char_count=0,
                    )
            chat_history = [m for m in chat_history if m.id > cutoff_id]
        user_memory_context = get_memories(user, db_session)
        custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
        # Strip memories from the prompt context if the user opted out.
        prompt_memory_context = (
            user_memory_context
            if user.use_memories
            else user_memory_context.without_memories()
        )
        max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
            custom_agent_prompt or ""
        )
        reserved_token_count = calculate_reserved_tokens(
            db_session=db_session,
            persona_system_prompt=max_reserved_system_prompt_tokens_str,
            token_counter=token_counter,
            files=new_msg_req.file_descriptors,
            user_memory_context=prompt_memory_context,
        )
        context_user_files = resolve_context_user_files(
            persona=persona,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )
        # Use the smallest context window across all models for safety
        min_context_window = min(llm.config.max_input_tokens for llm in llms)
        extracted_context_files = extract_context_files(
            user_files=context_user_files,
            llm_max_context_window=min_context_window,
            reserved_token_count=reserved_token_count,
            db_session=db_session,
        )
        search_params = determine_search_params(
            persona_id=persona.id,
            project_id=chat_session.project_id,
            extracted_context_files=extracted_context_files,
        )
        # Persona-attached files are always made available (deduplicated).
        if persona.user_files:
            existing = set(available_files.user_file_ids)
            for uf in persona.user_files:
                if uf.id not in existing:
                    available_files.user_file_ids.append(uf.id)
        all_tools = get_tools(db_session)
        tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
        search_tool_id = next(
            (tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
            None,
        )
        # Drop a forced search-tool request when search is disabled.
        forced_tool_id = new_msg_req.forced_tool_id
        if (
            search_params.search_usage == SearchToolUsage.DISABLED
            and forced_tool_id is not None
            and search_tool_id is not None
            and forced_tool_id == search_tool_id
        ):
            forced_tool_id = None
        files = load_all_chat_files(chat_history, db_session)
        chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
        # ── Reserve N assistant message IDs ────────────────────────────
        reserved_messages = reserve_multi_model_message_ids(
            db_session=db_session,
            chat_session_id=chat_session.id,
            parent_message_id=user_message.id,
            model_display_names=model_display_names,
        )
        # Tell the client which message ID belongs to which model up front.
        yield MultiModelMessageResponseIDInfo(
            user_message_id=user_message.id,
            responses=[
                ModelResponseSlot(message_id=m.id, model_name=name)
                for m, name in zip(reserved_messages, model_display_names)
            ],
        )
        has_file_reader_tool = any(
            tool.in_code_tool_id == "file_reader" for tool in all_tools
        )
        chat_history_result = convert_chat_history(
            chat_history=chat_history,
            files=files,
            context_image_files=extracted_context_files.image_files,
            additional_context=new_msg_req.additional_context,
            token_counter=token_counter,
            tool_id_to_name_map=tool_id_to_name_map,
        )
        simple_chat_history = chat_history_result.simple_messages
        # Injected file metadata only matters when the file-reader tool exists.
        all_injected_file_metadata: dict[str, FileToolMetadata] = (
            chat_history_result.all_injected_file_metadata
            if has_file_reader_tool
            else {}
        )
        if summarized_file_metadata:
            for fid, meta in summarized_file_metadata.items():
                all_injected_file_metadata.setdefault(fid, meta)
        # Prepend the branch summary (if any) as an assistant message.
        if summary_message is not None:
            summary_simple = ChatMessageSimple(
                message=summary_message.message,
                token_count=summary_message.token_count,
                message_type=MessageType.ASSISTANT,
            )
            simple_chat_history.insert(0, summary_simple)
        # ── Stop signal and processing status ──────────────────────────
        cache = get_cache_backend()
        reset_cancel_status(chat_session.id, cache)

        def check_is_connected() -> bool:
            # Delegates to the cache-backed stop signal for this session.
            return check_stop_signal(chat_session.id, cache)

        set_processing_status(
            chat_session_id=chat_session.id,
            cache=cache,
            value=True,
        )
        # Release the main session's read transaction before the long stream
        db_session.commit()
        # ── Parallel model execution ───────────────────────────────────
        # Each model thread writes tagged packets to this shared queue.
        # Sentinel _MODEL_DONE signals that a thread finished.
        merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
            queue.Queue()
        )
        # Track per-model state containers for completion callbacks
        state_containers: list[ChatStateContainer] = [
            ChatStateContainer() for _ in range(n_models)
        ]
        # Track which models completed successfully (for completion callbacks)
        model_succeeded: list[bool] = [False] * n_models
        user_identity = LLMUserIdentity(
            user_id=llm_user_identifier,
            session_id=str(chat_session.id),
        )

        def _run_model(model_idx: int) -> None:
            """Run a single model in a worker thread.

            Uses _ModelIndexEmitter so packets flow directly to merged_queue
            in real-time (not batched after completion). This enables true
            parallel streaming where both models' tokens interleave on the wire.

            DB access: tools may need a session during execution (e.g., search
            tool). Each thread creates its own session via context manager.
            """
            model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
            sc = state_containers[model_idx]
            model_llm = llms[model_idx]
            try:
                # Each model thread gets its own DB session for tool execution.
                # The session is scoped to the thread and closed when done.
                with get_session_with_current_tenant() as thread_db_session:
                    # Construct tools per-thread with thread-local DB session
                    thread_tool_dict = construct_tools(
                        persona=persona,
                        db_session=thread_db_session,
                        emitter=model_emitter,
                        user=user,
                        llm=model_llm,
                        search_tool_config=SearchToolConfig(
                            user_selected_filters=new_msg_req.internal_search_filters,
                            project_id_filter=search_params.project_id_filter,
                            persona_id_filter=search_params.persona_id_filter,
                            bypass_acl=False,
                            enable_slack_search=_should_enable_slack_search(
                                persona, new_msg_req.internal_search_filters
                            ),
                        ),
                        custom_tool_config=CustomToolConfig(
                            chat_session_id=chat_session.id,
                            message_id=user_message.id,
                            additional_headers=custom_tool_additional_headers,
                            mcp_headers=mcp_headers,
                        ),
                        file_reader_tool_config=FileReaderToolConfig(
                            user_file_ids=available_files.user_file_ids,
                            chat_file_ids=available_files.chat_file_ids,
                        ),
                        allowed_tool_ids=new_msg_req.allowed_tool_ids,
                        search_usage_forcing_setting=search_params.search_usage,
                    )
                    # Flatten the tool dict into a single list for the loop.
                    model_tools: list[Tool] = []
                    for tool_list in thread_tool_dict.values():
                        model_tools.extend(tool_list)
                    # Run the LLM loop — this blocks until the model finishes.
                    # Packets flow to merged_queue in real-time via the emitter.
                    run_llm_loop(
                        emitter=model_emitter,
                        state_container=sc,
                        simple_chat_history=list(simple_chat_history),
                        tools=model_tools,
                        custom_agent_prompt=custom_agent_prompt,
                        context_files=extracted_context_files,
                        persona=persona,
                        user_memory_context=user_memory_context,
                        llm=model_llm,
                        token_counter=get_llm_token_counter(model_llm),
                        db_session=thread_db_session,
                        forced_tool_id=forced_tool_id,
                        user_identity=user_identity,
                        chat_session_id=str(chat_session.id),
                        chat_files=chat_files_for_tools,
                        include_citations=new_msg_req.include_citations,
                        all_injected_file_metadata=all_injected_file_metadata,
                        inject_memories_in_prompt=user.use_memories,
                    )
                    model_succeeded[model_idx] = True
            except Exception as e:
                # Surface the failure to the merge loop as a tagged exception.
                merged_queue.put((model_idx, e))
            finally:
                # Always post the sentinel — the sole completion signal per model.
                merged_queue.put((model_idx, _MODEL_DONE))

        # Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
        executor = ThreadPoolExecutor(
            max_workers=n_models,
            thread_name_prefix="multi-model",
        )
        futures = []
        try:
            for i in range(n_models):
                futures.append(executor.submit(_run_model, i))
            # ── Main thread: merge and yield packets ───────────────────
            models_remaining = n_models
            while models_remaining > 0:
                try:
                    model_idx, item = merged_queue.get(timeout=0.3)
                except queue.Empty:
                    # Check cancellation during idle periods
                    if not check_is_connected():
                        yield Packet(
                            placement=Placement(turn_index=0),
                            obj=OverallStop(type="stop", stop_reason="user_cancelled"),
                        )
                        return
                    continue
                else:
                    # A queue item is one of: the done sentinel, a per-model
                    # exception, or a packet already tagged with model_index.
                    if item is _MODEL_DONE:
                        models_remaining -= 1
                        continue
                    if isinstance(item, Exception):
                        # Yield error as a tagged StreamingError packet.
                        # Do NOT decrement models_remaining here — the finally block
                        # in _run_model always posts _MODEL_DONE, which is the sole
                        # completion signal. Decrementing here too would double-count
                        # and cause the loop to exit early, silently dropping the
                        # surviving models' responses.
                        error_msg = str(item)
                        stack_trace = "".join(
                            traceback.format_exception(
                                type(item), item, item.__traceback__
                            )
                        )
                        # Redact API keys from error messages
                        model_llm = llms[model_idx]
                        if (
                            model_llm.config.api_key
                            and len(model_llm.config.api_key) > 2
                        ):
                            error_msg = error_msg.replace(
                                model_llm.config.api_key, "[REDACTED_API_KEY]"
                            )
                            stack_trace = stack_trace.replace(
                                model_llm.config.api_key, "[REDACTED_API_KEY]"
                            )
                        yield StreamingError(
                            error=error_msg,
                            stack_trace=stack_trace,
                            error_code="MODEL_ERROR",
                            is_retryable=True,
                            details={
                                "model": model_llm.config.model_name,
                                "provider": model_llm.config.model_provider,
                                "model_index": model_idx,
                            },
                        )
                        continue
                    if isinstance(item, Packet):
                        # Packet is already tagged with model_index by _ModelIndexEmitter
                        yield item
            # ── Completion: save each successful model's response ──────
            # Run completion callbacks on the main thread using the main
            # session. This is safe because all worker threads have exited
            # by this point (merged_queue fully drained).
            for i in range(n_models):
                if not model_succeeded[i]:
                    continue
                try:
                    llm_loop_completion_handle(
                        state_container=state_containers[i],
                        is_connected=check_is_connected,
                        db_session=db_session,
                        assistant_message=reserved_messages[i],
                        llm=llms[i],
                        reserved_tokens=reserved_token_count,
                    )
                except Exception:
                    # One model failing to persist must not block the others.
                    logger.exception(
                        f"Failed completion for model {i} "
                        f"({model_display_names[i]})"
                    )
            yield Packet(
                placement=Placement(turn_index=0),
                obj=OverallStop(type="stop", stop_reason="complete"),
            )
        finally:
            # Don't block on shutdown — futures making live LLM API calls
            # cannot be cancelled once started, so wait=True would block
            # the generator (and the HTTP response) until all calls finish.
            # wait=False lets threads complete in the background.
            executor.shutdown(wait=False)
    except ValueError as e:
        logger.exception("Failed to process multi-model chat message.")
        yield StreamingError(
            error=str(e),
            error_code="VALIDATION_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
        return
    except Exception as e:
        logger.exception(f"Failed multi-model chat: {e}")
        stack_trace = traceback.format_exc()
        yield StreamingError(
            error=str(e),
            stack_trace=stack_trace,
            error_code="MULTI_MODEL_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
    finally:
        # Best-effort: always clear the processing flag, even after failure.
        try:
            if cache is not None and chat_session is not None:
                set_processing_status(
                    chat_session_id=chat_session.id,
                    cache=cache,
                    value=False,
                )
        except Exception:
            logger.exception("Error clearing processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],

View File

@@ -332,6 +332,10 @@ OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
else None
)
# Opt-in feature flag read from the environment: enabled only when the env var
# is set to the literal string "true" (case-insensitive); any other value, or
# an unset var, leaves it False. Presumably switches the Search UI to
# OpenSearch keyword search — NOTE(review): name-based; confirm at usage sites.
ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
    os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
    == "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a

View File

@@ -24,11 +24,11 @@ CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
# Socket read timeout for LLM HTTP calls, in seconds. Defaults to 60.
LLM_SOCKET_READ_TIMEOUT = int(
    os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60"
)  # 60 seconds
# Weighting factor between vector and keyword search; 1 for completely vector
# search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from
# the env outside of this range will be clipped to the respective end of the
# range. Defaults to 0.5.
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
# Same weighting with the same [0, 1] clipping, defaulting to 0.4 —
# presumably applied to keyword-leaning queries; confirm at usage sites.
HYBRID_ALPHA_KEYWORD = max(
    0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear

View File

@@ -123,7 +123,7 @@ class OnyxConfluence:
self.shared_base_kwargs: dict[str, str | int | bool] = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"backoff_and_retry": False,
"cloud": is_cloud,
}
if timeout:
@@ -456,7 +456,7 @@ class OnyxConfluence:
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)

View File

@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
raise e
if e.response.status_code >= 500:
if attempt >= max_retries - 1:
raise e
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
logger.warning(
f"Server error {e.response.status_code}. "
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
)
return math.ceil(time.monotonic() + delay)
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -53,7 +53,7 @@ class NotionPage(BaseModel):
id: str
created_time: str
last_edited_time: str
archived: bool
in_trash: bool
properties: dict[str, Any]
url: str
@@ -63,6 +63,13 @@ class NotionPage(BaseModel):
)
class NotionDataSource(BaseModel):
"""Represents a Notion Data Source within a database."""
id: str
name: str = ""
class NotionBlock(BaseModel):
"""Represents a Notion Block object"""
@@ -107,7 +114,7 @@ class NotionConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.headers = {
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
"Notion-Version": "2026-03-11",
}
self.indexed_pages: set[str] = set()
self.root_page_id = root_page_id
@@ -127,6 +134,9 @@ class NotionConnector(LoadConnector, PollConnector):
# Maps child page IDs to their containing page ID (discovered in _read_blocks).
# Used to resolve block_id parent types to the actual containing page.
self._child_page_parent_map: dict[str, str] = {}
# Maps data_source_id -> database_id (populated in _read_pages_from_database).
# Used to resolve data_source_id parent types back to the database.
self._data_source_to_database_map: dict[str, str] = {}
@classmethod
@override
@@ -227,7 +237,11 @@ class NotionConnector(LoadConnector, PollConnector):
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page."""
"""Attempt to fetch a database as a page.
Note: As of API 2025-09-03, database objects no longer include
`properties` (schema moved to individual data sources).
"""
logger.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
res = rl_requests.get(
@@ -246,18 +260,52 @@ class NotionConnector(LoadConnector, PollConnector):
database_name[0].get("text", {}).get("content") if database_name else None
)
db_data.setdefault("properties", {})
return NotionPage(**db_data, database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
self, database_id: str, cursor: str | None = None
def _fetch_data_sources_for_database(
self, database_id: str
) -> list[NotionDataSource]:
"""Fetch the list of data sources for a database."""
logger.debug(f"Fetching data sources for database '{database_id}'")
res = rl_requests.get(
f"https://api.notion.com/v1/databases/{database_id}",
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
if res.status_code in (403, 404):
logger.error(
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared "
f"with the Onyx integration. Exact exception:\n{e}"
)
return []
logger.exception(f"Error fetching database - {res.json()}")
raise e
db_data = res.json()
data_sources = db_data.get("data_sources", [])
return [
NotionDataSource(id=ds["id"], name=ds.get("name", ""))
for ds in data_sources
if ds.get("id")
]
@retry(tries=3, delay=1, backoff=2)
def _fetch_data_source(
self, data_source_id: str, cursor: str | None = None
) -> dict[str, Any]:
"""Fetch a database from it's ID via the Notion API."""
logger.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
"""Query a data source via POST /v1/data_sources/{id}/query."""
logger.debug(f"Querying data source '{data_source_id}'")
url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
body = None if not cursor else {"start_cursor": cursor}
res = rl_requests.post(
block_url,
url,
headers=self.headers,
json=body,
timeout=_NOTION_CALL_TIMEOUT,
@@ -265,25 +313,14 @@ class NotionConnector(LoadConnector, PollConnector):
try:
res.raise_for_status()
except Exception as e:
json_data = res.json()
code = json_data.get("code")
# Sep 3 2025 backend changed the error message for this case
# TODO: it is also now possible for there to be multiple data sources per database; at present we
# just don't handle that. We will need to upgrade the API to the current version + query the
# new data sources endpoint to handle that case correctly.
if code == "object_not_found" or (
code == "validation_error"
and "does not contain any data sources" in json_data.get("message", "")
):
# this happens when a database is not shared with the integration
# in this case, we should just ignore the database
if res.status_code in (403, 404):
logger.error(
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared "
f"Unable to access data source with ID '{data_source_id}'. "
f"This is likely due to it not being shared "
f"with the Onyx integration. Exact exception:\n{e}"
)
return {"results": [], "next_cursor": None}
logger.exception(f"Error fetching database - {res.json()}")
logger.exception(f"Error querying data source - {res.json()}")
raise e
return res.json()
@@ -348,8 +385,9 @@ class NotionConnector(LoadConnector, PollConnector):
# Fallback to workspace if we don't know the parent
return self.workspace_id
elif parent_type == "data_source_id":
# Newer Notion API may use data_source_id for databases
return parent.get("database_id") or parent.get("data_source_id")
ds_id = parent.get("data_source_id")
if ds_id:
return self._data_source_to_database_map.get(ds_id, self.workspace_id)
elif parent_type in ["page_id", "database_id"]:
return parent.get(parent_type)
@@ -497,18 +535,32 @@ class NotionConnector(LoadConnector, PollConnector):
if db_node:
hierarchy_nodes.append(db_node)
cursor = None
while True:
data = self._fetch_database(database_id, cursor)
# Discover all data sources under this database, then query each one.
# Even legacy single-source databases have one entry in the array.
data_sources = self._fetch_data_sources_for_database(database_id)
if not data_sources:
logger.warning(
f"Database '{database_id}' returned zero data sources — "
f"no pages will be indexed from this database."
)
for ds in data_sources:
self._data_source_to_database_map[ds.id] = database_id
cursor = None
while True:
data = self._fetch_data_source(ds.id, cursor)
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(
NotionBlock(id=obj_id, text=text, prefix="\n")
)
if not self.recursive_index_enabled:
continue
if self.recursive_index_enabled:
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
@@ -518,7 +570,6 @@ class NotionConnector(LoadConnector, PollConnector):
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
# Get nested database name from properties if available
nested_db_title = result.get("title", [])
nested_db_name = None
if nested_db_title and len(nested_db_title) > 0:
@@ -533,10 +584,10 @@ class NotionConnector(LoadConnector, PollConnector):
result_pages.extend(nested_output.child_page_ids)
hierarchy_nodes.extend(nested_output.hierarchy_nodes)
if data["next_cursor"] is None:
break
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
cursor = data["next_cursor"]
return BlockReadOutput(
blocks=result_blocks,
@@ -807,36 +858,55 @@ class NotionConnector(LoadConnector, PollConnector):
def _yield_database_hierarchy_nodes(
self,
) -> Generator[HierarchyNode | Document, None, None]:
"""Search for all databases and yield hierarchy nodes for each.
"""Search for all data sources and yield hierarchy nodes for their parent databases.
This must be called BEFORE page indexing so that database hierarchy nodes
exist when pages inside databases reference them as parents.
With the new API, search returns data source objects instead of databases.
Multiple data sources can share the same parent database, so we use
database_id as the hierarchy node key and deduplicate via
_maybe_yield_hierarchy_node.
"""
query_dict: dict[str, Any] = {
"filter": {"property": "object", "value": "database"},
"filter": {"property": "object", "value": "data_source"},
"page_size": _NOTION_PAGE_SIZE,
}
pages_seen = 0
while pages_seen < _MAX_PAGES:
db_res = self._search_notion(query_dict)
for db in db_res.results:
db_id = db["id"]
# Extract title from the title array
title_arr = db.get("title", [])
db_name = None
if title_arr:
db_name = " ".join(
t.get("plain_text", "") for t in title_arr
).strip()
if not db_name:
for ds in db_res.results:
# Extract the parent database_id from the data source's parent
ds_parent = ds.get("parent", {})
db_id = ds_parent.get("database_id")
if not db_id:
continue
# Populate the mapping so _get_parent_raw_id can resolve later
ds_id = ds.get("id")
if not ds_id:
continue
self._data_source_to_database_map[ds_id] = db_id
# Fetch the database to get its actual name and parent
try:
db_page = self._fetch_database_as_page(db_id)
db_name = db_page.database_name or f"Database {db_id}"
parent_raw_id = self._get_parent_raw_id(db_page.parent)
db_url = (
db_page.url or f"https://notion.so/{db_id.replace('-', '')}"
)
except requests.exceptions.RequestException as e:
logger.warning(
f"Could not fetch database '{db_id}', "
f"defaulting to workspace root. Error: {e}"
)
db_name = f"Database {db_id}"
parent_raw_id = self.workspace_id
db_url = f"https://notion.so/{db_id.replace('-', '')}"
# Get parent using existing helper
parent_raw_id = self._get_parent_raw_id(db.get("parent"))
# Notion URLs omit dashes from UUIDs
db_url = db.get("url") or f"https://notion.so/{db_id.replace('-', '')}"
# _maybe_yield_hierarchy_node deduplicates by raw_node_id,
# so multiple data sources under one database produce one node.
node = self._maybe_yield_hierarchy_node(
raw_node_id=db_id,
raw_parent_id=parent_raw_id or self.workspace_id,

View File

@@ -1,5 +1,6 @@
import base64
import copy
import fnmatch
import html
import io
import os
@@ -84,6 +85,44 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool:
"""Check if a site URL matches any of the exclusion glob patterns."""
for pattern in excluded_site_patterns:
if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch(
site_url.rstrip("/"), pattern.rstrip("/")
):
return True
return False
def _is_path_excluded(item_path: str, excluded_path_patterns: list[str]) -> bool:
"""Check if a drive item path matches any of the exclusion glob patterns.
item_path is the relative path within a drive, e.g. "Engineering/API/report.docx".
Matches are attempted against the full path and the filename alone so that
patterns like "*.tmp" match files at any depth.
"""
filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path
for pattern in excluded_path_patterns:
if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern):
return True
return False
def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
"""Build the relative path of a drive item from its parentReference.path and name.
Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx"
=> "Eng/API/report.docx"
"""
if parent_reference_path and "root:/" in parent_reference_path:
folder = unquote(parent_reference_path.split("root:/", 1)[1])
if folder:
return f"{folder}/{item_name}"
return item_name
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
@@ -478,6 +517,7 @@ def _convert_driveitem_to_document_with_permissions(
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
access_token: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> Document | ConnectorFailure | None:
if not driveitem.name or not driveitem.id:
@@ -610,6 +650,7 @@ def _convert_driveitem_to_document_with_permissions(
drive_item=sdk_item,
drive_name=drive_name,
add_prefix=True,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
else:
external_access = ExternalAccess.empty()
@@ -644,6 +685,7 @@ def _convert_sitepage_to_document(
graph_client: GraphClient,
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> Document:
"""Convert a SharePoint site page to a Document object."""
# Extract text content from the site page
@@ -773,6 +815,7 @@ def _convert_sitepage_to_document(
graph_client=graph_client,
site_page=site_page,
add_prefix=True,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
else:
external_access = ExternalAccess.empty()
@@ -803,6 +846,7 @@ def _convert_driveitem_to_slim_document(
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -813,6 +857,7 @@ def _convert_driveitem_to_slim_document(
graph_client=graph_client,
drive_item=sdk_item,
drive_name=drive_name,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
return SlimDocument(
@@ -827,6 +872,7 @@ def _convert_sitepage_to_slim_document(
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -836,6 +882,7 @@ def _convert_sitepage_to_slim_document(
ctx=ctx,
graph_client=graph_client,
site_page=site_page,
treat_sharing_link_as_public=treat_sharing_link_as_public,
)
id = site_page.get("id")
if id is None:
@@ -855,14 +902,20 @@ class SharepointConnector(
self,
batch_size: int = INDEX_BATCH_SIZE,
sites: list[str] = [],
excluded_sites: list[str] = [],
excluded_paths: list[str] = [],
include_site_pages: bool = True,
include_site_documents: bool = True,
treat_sharing_link_as_public: bool = False,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
) -> None:
self.batch_size = batch_size
self.sites = list(sites)
self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
self.treat_sharing_link_as_public = treat_sharing_link_as_public
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
@@ -1233,6 +1286,29 @@ class SharepointConnector(
break
sites = sites._get_next().execute_query()
def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
    """Check if a drive item should be excluded based on excluded_paths patterns."""
    patterns = self.excluded_paths
    if not patterns:
        # Fast path: no path denylist configured, nothing can be excluded.
        return False
    # Reconstruct the item's drive-relative path before glob matching.
    relative_path = _build_item_relative_path(
        driveitem.parent_reference_path, driveitem.name
    )
    return _is_path_excluded(relative_path, patterns)
def _filter_excluded_sites(
    self, site_descriptors: list[SiteDescriptor]
) -> list[SiteDescriptor]:
    """Remove sites matching any excluded_sites glob pattern."""
    if not self.excluded_sites:
        return site_descriptors
    kept: list[SiteDescriptor] = []
    for descriptor in site_descriptors:
        if _is_site_excluded(descriptor.url, self.excluded_sites):
            # Log each drop so operators can audit denylist behavior.
            logger.info(f"Excluding site by denylist: {descriptor.url}")
        else:
            kept.append(descriptor)
    return kept
def fetch_sites(self) -> list[SiteDescriptor]:
sites = self.graph_client.sites.get_all_sites().execute_query()
@@ -1249,7 +1325,7 @@ class SharepointConnector(
for site in self._handle_paginated_sites(sites)
if "-my.sharepoint" not in site.web_url
]
return site_descriptors
return self._filter_excluded_sites(site_descriptors)
def _fetch_site_pages(
self,
@@ -1690,7 +1766,9 @@ class SharepointConnector(
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
# Create a temporary checkpoint for hierarchy node tracking
temp_checkpoint = SharepointConnectorCheckpoint(has_more=True)
@@ -1710,6 +1788,10 @@ class SharepointConnector(
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
continue
if drive_web_url:
doc_batch.extend(
self._yield_drive_hierarchy_node(
@@ -1747,6 +1829,7 @@ class SharepointConnector(
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
except Exception as e:
@@ -1770,6 +1853,7 @@ class SharepointConnector(
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:
@@ -2043,7 +2127,9 @@ class SharepointConnector(
and not checkpoint.process_site_pages
):
logger.info("Initializing SharePoint sites for processing")
site_descs = self.site_descriptors or self.fetch_sites()
site_descs = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
checkpoint.cached_site_descriptors = deque(site_descs)
if not checkpoint.cached_site_descriptors:
@@ -2264,6 +2350,10 @@ class SharepointConnector(
for driveitem in driveitems:
item_count += 1
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
continue
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
logger.debug(
f"Skipping duplicate document {driveitem.id} ({driveitem.name})"
@@ -2318,6 +2408,7 @@ class SharepointConnector(
parent_hierarchy_raw_node_id=parent_hierarchy_url,
graph_api_base=self.graph_api_base,
access_token=access_token,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
if isinstance(doc_or_failure, Document):
@@ -2398,6 +2489,7 @@ class SharepointConnector(
include_permissions=include_permissions,
# Site pages have the site as their parent
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
logger.info(

View File

@@ -17,6 +17,7 @@ def get_sharepoint_external_access(
drive_name: str | None = None,
site_page: dict[str, Any] | None = None,
add_prefix: bool = False,
treat_sharing_link_as_public: bool = False,
) -> ExternalAccess:
if drive_item and drive_item.id is None:
raise ValueError("DriveItem ID is required")
@@ -34,7 +35,13 @@ def get_sharepoint_external_access(
)
external_access = get_external_access_func(
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
ctx,
graph_client,
drive_name,
drive_item,
site_page,
add_prefix,
treat_sharing_link_as_public,
)
return external_access

View File

@@ -401,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
section in addition to the match_highlights."""
content: str
class PersonaSearchInfo(BaseModel):
    """Snapshot of persona data needed by the search pipeline.

    Extracted from the ORM Persona before the DB session is released so that
    SearchTool and search_pipeline never lazy-load relationships post-commit.
    """

    # Names of the persona's document sets (used for document-set filtering).
    document_set_names: list[str]
    # Earliest-document time cutoff for searches; None means no cutoff.
    search_start_date: datetime | None
    # IDs of documents explicitly attached to the persona.
    attached_document_ids: list[str]
    # IDs of hierarchy nodes scoping the persona's knowledge.
    hierarchy_node_ids: list[int]

View File

@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.context.search.retrieval.search_runner import search_chunks
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
@@ -247,8 +247,8 @@ def search_pipeline(
document_index: DocumentIndex,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Used for default filters and settings
persona: Persona | None,
# Pre-extracted persona search configuration (None when no persona)
persona_search_info: PersonaSearchInfo | None,
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
@@ -263,24 +263,18 @@ def search_pipeline(
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
persona_document_sets: list[str] | None = (
[persona_document_set.name for persona_document_set in persona.document_sets]
if persona
else None
persona_search_info.document_set_names if persona_search_info else None
)
persona_time_cutoff: datetime | None = (
persona.search_start_date if persona else None
persona_search_info.search_start_date if persona_search_info else None
)
# Extract assistant knowledge filters from persona
attached_document_ids: list[str] | None = (
[doc.id for doc in persona.attached_documents]
if persona and persona.attached_documents
persona_search_info.attached_document_ids or None
if persona_search_info
else None
)
hierarchy_node_ids: list[int] | None = (
[node.id for node in persona.hierarchy_nodes]
if persona and persona.hierarchy_nodes
else None
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
)
filters = _build_index_filters(

View File

@@ -14,6 +14,10 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
@@ -49,7 +53,7 @@ def combine_retrieval_results(
return sorted_chunks
def _embed_and_search(
def _embed_and_hybrid_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
@@ -81,6 +85,17 @@ def _embed_and_search(
return top_chunks
def _keyword_search(
    query_request: ChunkIndexRequest,
    document_index: NewDocumentIndex,
) -> list[InferenceChunk]:
    """Run a keyword-only retrieval against the new document index.

    No query embedding is generated; this delegates directly to the index's
    keyword_retrieval implementation.
    """
    limit = query_request.limit or NUM_RETURNED_HITS
    return document_index.keyword_retrieval(
        query=query_request.query,
        filters=query_request.filters,
        num_to_retrieve=limit,
    )
def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
@@ -128,21 +143,38 @@ def search_chunks(
)
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
if (
query_request.hybrid_alpha is not None
and query_request.hybrid_alpha == 0.0
and isinstance(document_index, OpenSearchOldDocumentIndex)
):
# If hybrid alpha is explicitly set to keyword only, do pure keyword
# search without generating an embedding. This is currently only
# supported with OpenSearchDocumentIndex.
opensearch_new_document_index: NewDocumentIndex = document_index._real_index
run_queries.append(
(
lambda: _keyword_search(
query_request, opensearch_new_document_index
),
(),
)
)
else:
run_queries.append(
(
_embed_and_hybrid_search,
(query_request, document_index, db_session, embedding_model),
)
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.debug(
f"Hybrid search returned no results for query: {query_request.query}with filters: {query_request.filters}"
f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}."
)
return []
return top_chunks

View File

@@ -64,6 +64,9 @@ def get_chat_session_by_id(
joinedload(ChatSession.persona).options(
selectinload(Persona.tools),
selectinload(Persona.user_files),
selectinload(Persona.document_sets),
selectinload(Persona.attached_documents),
selectinload(Persona.hierarchy_nodes),
),
joinedload(ChatSession.project),
)
@@ -614,6 +617,80 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
    db_session: Session,
    chat_session_id: UUID,
    parent_message_id: int,
    model_display_names: list[str],
) -> list[ChatMessage]:
    """Reserve N assistant message placeholders for multi-model parallel streaming.

    All messages share the same parent (the user message). The parent's
    latest_child_message_id points to the LAST reserved message so that the
    default history-chain walker picks it up.

    Args:
        db_session: Active session; this function commits before returning.
        chat_session_id: Chat session the placeholders belong to.
        parent_message_id: ID of the shared parent (user) message.
        model_display_names: One placeholder is reserved per display name.

    Returns:
        The reserved ChatMessage rows in the same order as
        model_display_names; empty list when no display names are given.
    """
    # Guard: with no models requested there is nothing to reserve, and the
    # parent-pointer update below would raise IndexError on reserved[-1].
    if not model_display_names:
        return []

    reserved: list[ChatMessage] = []
    for display_name in model_display_names:
        msg = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            latest_child_message_id=None,
            message="Response was terminated prior to completion, try regenerating.",
            token_count=15,  # placeholder; updated on completion by the llm loop completion handler
            message_type=MessageType.ASSISTANT,
            model_display_name=display_name,
        )
        db_session.add(msg)
        reserved.append(msg)

    # Flush to assign IDs without committing yet
    db_session.flush()

    # Point parent's latest_child to the last reserved message. get() is a
    # primary-key lookup (and may hit the session's identity map directly).
    parent = db_session.get(ChatMessage, parent_message_id)
    if parent:
        parent.latest_child_message_id = reserved[-1].id

    db_session.commit()
    return reserved
def set_preferred_response(
    db_session: Session,
    user_message_id: int,
    preferred_assistant_message_id: int,
) -> None:
    """Set the preferred assistant response for a multi-model user message.

    Validates that the user message is a USER type and that the preferred
    assistant message is a direct child of that user message.

    Raises:
        ValueError: if either message is missing, the user message is not a
            USER message, or the assistant message is not its direct child.
    """
    # Validate the user-message side first so callers get the most specific
    # error for a bad user_message_id.
    user_message = db_session.get(ChatMessage, user_message_id)
    if user_message is None:
        raise ValueError(f"User message {user_message_id} not found")
    if user_message.message_type != MessageType.USER:
        raise ValueError(f"Message {user_message_id} is not a user message")

    candidate = db_session.get(ChatMessage, preferred_assistant_message_id)
    if candidate is None:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} not found"
        )
    if candidate.parent_message_id != user_message_id:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} is not a child "
            f"of user message {user_message_id}"
        )

    # Record the preference and make the preferred branch the one the
    # default history-chain walker follows.
    user_message.preferred_response_id = preferred_assistant_message_id
    user_message.latest_child_message_id = preferred_assistant_message_id
    db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -836,6 +913,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -750,3 +750,31 @@ def resync_cc_pair(
)
db_session.commit()
# ── Metrics query helpers ──────────────────────────────────────────────
def get_connector_health_for_metrics(
    db_session: Session,
) -> list:  # Returns list of Row tuples
    """Return connector health data for Prometheus metrics.

    Each row is (cc_pair_id, status, in_repeated_error_state,
    last_successful_index_time, name, source).
    """
    query = db_session.query(
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.status,
        ConnectorCredentialPair.in_repeated_error_state,
        ConnectorCredentialPair.last_successful_index_time,
        ConnectorCredentialPair.name,
        Connector.source,
    ).join(
        Connector,
        ConnectorCredentialPair.connector_id == Connector.id,
    )
    return query.all()

View File

@@ -1,4 +1,31 @@
from __future__ import annotations
from enum import Enum as PyEnum
from typing import ClassVar
class AccountType(str, PyEnum):
    """
    What kind of account this is — determines whether the user
    enters the group-based permission system.

    STANDARD + SERVICE_ACCOUNT → participate in group system
    BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
    """

    # Regular human user; participates in the group permission system.
    STANDARD = "standard"
    # Automated bot account; fixed behavior outside the group system.
    BOT = "bot"
    # External-permissions user (presumably synced from external ACLs — confirm).
    EXT_PERM_USER = "ext_perm_user"
    # Non-human integration account; participates in the group system.
    SERVICE_ACCOUNT = "service_account"
    # Unauthenticated access; fixed behavior outside the group system.
    ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
    """How a permission grant was created."""

    # Granted manually by a user/admin.
    USER = "user"
    # Provisioned via SCIM sync.
    SCIM = "scim"
    # Created automatically by the system.
    SYSTEM = "system"
class IndexingStatus(str, PyEnum):
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
class HookFailStrategy(str, PyEnum):
HARD = "hard" # exception propagates, pipeline aborts
SOFT = "soft" # log error, return original input, pipeline continues
class Permission(str, PyEnum):
    """
    Permission tokens for group-based authorization.

    19 tokens total. full_admin_panel_access is an override —
    if present, any permission check passes.
    """

    # Basic (auto-granted to every new group)
    BASIC_ACCESS = "basic"

    # Read tokens — implied only, never granted directly (see Permission.IMPLIED)
    READ_CONNECTORS = "read:connectors"
    READ_DOCUMENT_SETS = "read:document_sets"
    READ_AGENTS = "read:agents"
    READ_USERS = "read:users"

    # Add / Manage pairs
    ADD_AGENTS = "add:agents"
    MANAGE_AGENTS = "manage:agents"
    MANAGE_DOCUMENT_SETS = "manage:document_sets"
    ADD_CONNECTORS = "add:connectors"
    MANAGE_CONNECTORS = "manage:connectors"
    MANAGE_LLMS = "manage:llms"

    # Toggle tokens
    READ_AGENT_ANALYTICS = "read:agent_analytics"
    MANAGE_ACTIONS = "manage:actions"
    READ_QUERY_HISTORY = "read:query_history"
    MANAGE_USER_GROUPS = "manage:user_groups"
    CREATE_USER_API_KEYS = "create:user_api_keys"
    CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
    CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"

    # Override — any permission check passes
    FULL_ADMIN_PANEL_ACCESS = "admin"

    # Permissions that are implied by other grants and must never be stored
    # directly in the permission_grant table.
    # Annotation only — assigned after the class body so the frozenset is not
    # turned into an enum member by the Enum machinery.
    IMPLIED: ClassVar[frozenset[Permission]]


# Populate Permission.IMPLIED after class creation (see annotation in the class body).
Permission.IMPLIED = frozenset(
    {
        Permission.READ_CONNECTORS,
        Permission.READ_DOCUMENT_SETS,
        Permission.READ_AGENTS,
        Permission.READ_USERS,
    }
)

View File

@@ -75,6 +75,7 @@ def create_hook__no_commit(
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
is_reachable: bool | None = None,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
@@ -100,6 +101,7 @@ def create_hook__no_commit(
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
is_reachable=is_reachable,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,

View File

@@ -2,6 +2,8 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import NamedTuple
from typing import TYPE_CHECKING
from typing import TypeVarTuple
from sqlalchemy import and_
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
if TYPE_CHECKING:
from onyx.configs.constants import DocumentSource
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
@@ -972,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.scalars(stmt).all())
# ── Metrics query helpers ──────────────────────────────────────────────
class ActiveIndexAttemptMetric(NamedTuple):
    """Row returned by get_active_index_attempts_for_metrics."""

    # Current (non-terminal) status of the index attempt.
    status: IndexingStatus
    # Source type of the connector (forward ref; imported under TYPE_CHECKING).
    source: "DocumentSource"
    # Connector-credential pair ID and its (optional) display name.
    cc_pair_id: int
    cc_pair_name: str | None
    # Number of attempts in this (status, source, cc_pair) bucket.
    attempt_count: int
def get_active_index_attempts_for_metrics(
    db_session: Session,
) -> list[ActiveIndexAttemptMetric]:
    """Return non-terminal index attempts grouped by status, source, and connector.

    Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
    """
    # Local import — presumably to avoid a circular import at module load;
    # confirm before hoisting to the top of the file.
    from onyx.db.models import Connector

    terminal_statuses = [
        status for status in IndexingStatus if status.is_terminal()
    ]

    # Select and group by the same columns so each row maps 1:1 onto the
    # ActiveIndexAttemptMetric fields (plus the trailing count).
    group_columns = (
        IndexAttempt.status,
        Connector.source,
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.name,
    )
    rows = (
        db_session.query(*group_columns, func.count())
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .filter(IndexAttempt.status.notin_(terminal_statuses))
        .group_by(*group_columns)
        .all()
    )
    return [ActiveIndexAttemptMetric(*row) for row in rows]
def get_failed_attempt_counts_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Return {cc_pair_id: failed_attempt_count} for all connectors.

    When ``since`` is provided, only attempts created after that timestamp
    are counted. Defaults to the last 90 days to avoid unbounded historical
    aggregation.
    """
    # Default window: the last 90 days.
    cutoff = (
        since
        if since is not None
        else datetime.now(timezone.utc) - timedelta(days=90)
    )
    rows = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            func.count(),
        )
        .filter(
            IndexAttempt.status == IndexingStatus.FAILED,
            IndexAttempt.time_created >= cutoff,
        )
        .group_by(IndexAttempt.connector_credential_pair_id)
        .all()
    )
    # Rows are (cc_pair_id, count) pairs.
    return dict(rows)
def get_docs_indexed_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Return {cc_pair_id: total_new_docs_indexed} across successful attempts.

    Only counts attempts with status SUCCESS to avoid inflating counts with
    partial results from failed attempts. When ``since`` is provided, only
    attempts created after that timestamp are included.
    """
    if since is None:
        # Bound the aggregation window to the last 90 days by default.
        since = datetime.now(timezone.utc) - timedelta(days=90)

    rows = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
        )
        .filter(IndexAttempt.status == IndexingStatus.SUCCESS)
        .filter(IndexAttempt.time_created >= since)
        .group_by(IndexAttempt.connector_credential_pair_id)
        .all()
    )
    # SUM may come back as Decimal or None depending on backend; normalize to int.
    return {cc_pair_id: int(total or 0) for cc_pair_id, total in rows}

View File

@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy import PrimaryKeyConstraint
from onyx.db.enums import AccountType
from onyx.auth.schemas import UserRole
from onyx.configs.constants import (
ANONYMOUS_USER_UUID,
@@ -78,6 +79,8 @@ from onyx.db.enums import (
MCPAuthenticationPerformer,
MCPTransport,
MCPServerStatus,
Permission,
GrantSource,
LLMModelFlowType,
ThemePreference,
DefaultAppMode,
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
Preferences probably should be in a separate table at some point, but for now
@@ -2645,6 +2651,15 @@ class ChatMessage(Base):
nullable=True,
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
# The display name of the model that generated this assistant message
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# What does this message contain
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -2712,6 +2727,12 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
@@ -3971,6 +3992,8 @@ class SamlAccount(Base):
class User__UserGroup(Base):
__tablename__ = "user__user_group"
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
user_group_id: Mapped[int] = mapped_column(
@@ -3981,6 +4004,48 @@ class User__UserGroup(Base):
)
class PermissionGrant(Base):
    """A single permission token granted to a user group."""

    __tablename__ = "permission_grant"
    __table_args__ = (
        # One row per (group, permission) pair.
        UniqueConstraint(
            "group_id", "permission", name="uq_permission_grant_group_permission"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Group receiving the grant; rows are removed when the group is deleted.
    group_id: Mapped[int] = mapped_column(
        ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
    )
    # The granted permission token (see Permission enum).
    permission: Mapped[Permission] = mapped_column(
        Enum(Permission, native_enum=False), nullable=False
    )
    # How the grant came to exist (user action, SCIM sync, or system).
    grant_source: Mapped[GrantSource] = mapped_column(
        Enum(GrantSource, native_enum=False), nullable=False
    )
    # User who issued the grant; nulled out if that user is deleted.
    granted_by: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="SET NULL"), nullable=True
    )
    granted_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Soft-delete flag — presumably kept for auditability; confirm retention policy.
    is_deleted: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, server_default=text("false")
    )

    group: Mapped["UserGroup"] = relationship(
        "UserGroup", back_populates="permission_grants"
    )

    @validates("permission")
    def _validate_permission(self, _key: str, value: Permission) -> Permission:
        # Implied read tokens are derived from other grants and must never be
        # persisted directly; this is enforced at the ORM layer.
        if value in Permission.IMPLIED:
            raise ValueError(
                f"{value!r} is an implied permission and cannot be granted directly"
            )
        return value
class UserGroup__ConnectorCredentialPair(Base):
__tablename__ = "user_group__connector_credential_pair"
@@ -4075,6 +4140,8 @@ class UserGroup(Base):
is_up_for_deletion: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Last time a user updated this user group
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
@@ -4118,6 +4185,9 @@ class UserGroup(Base):
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
)
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
)
"""Tables related to Token Rate Limiting

View File

@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def get_default_behavior_persona(db_session: Session) -> Persona | None:
def get_default_behavior_persona(
    db_session: Session,
    eager_load_for_tools: bool = False,
) -> Persona | None:
    """Fetch the persona with DEFAULT_PERSONA_ID, or None if it does not exist.

    When eager_load_for_tools is True, eagerly loads the relationship
    collections read downstream (tools, document_sets, attached_documents,
    hierarchy_nodes) so they are available without lazy loading after the
    session is released.
    """
    stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
    if eager_load_for_tools:
        stmt = stmt.options(
            selectinload(Persona.tools),
            selectinload(Persona.document_sets),
            selectinload(Persona.attached_documents),
            selectinload(Persona.hierarchy_nodes),
        )
    return db_session.scalars(stmt).first()

View File

@@ -381,6 +381,47 @@ class HybridCapable(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def keyword_retrieval(
    self,
    query: str,
    filters: IndexFilters,
    num_to_retrieve: int,
) -> list[InferenceChunk]:
    """Runs keyword-only search and returns a list of inference chunks.

    No query embedding is involved; matching is done on the raw query text.

    Args:
        query: User query.
        filters: Filters for things like permissions, source type, time,
            etc.
        num_to_retrieve: Number of highest matching chunks to return.

    Returns:
        Score-ranked (highest first) list of highest matching chunks.
    """
    raise NotImplementedError
@abc.abstractmethod
def semantic_retrieval(
    self,
    query_embedding: Embedding,
    filters: IndexFilters,
    num_to_retrieve: int,
) -> list[InferenceChunk]:
    """Runs semantic-only (vector) search and returns a list of inference chunks.

    Args:
        query_embedding: Vector representation of the query. Must be of the
            correct dimensionality for the primary index.
        filters: Filters for things like permissions, source type, time,
            etc.
        num_to_retrieve: Number of highest matching chunks to return.

    Returns:
        Score-ranked (highest first) list of highest matching chunks.
    """
    raise NotImplementedError
class RandomCapable(abc.ABC):
"""

View File

@@ -18,10 +18,13 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
from onyx.server.metrics.opensearch_search import observe_opensearch_search
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@@ -256,7 +259,6 @@ class OpenSearchClient(AbstractContextManager):
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
@@ -304,6 +306,7 @@ class OpenSearchIndexClient(OpenSearchClient):
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
emit_metrics: bool = True,
):
super().__init__(
host=host,
@@ -315,6 +318,7 @@ class OpenSearchIndexClient(OpenSearchClient):
timeout=timeout,
)
self._index_name = index_name
self._emit_metrics = emit_metrics
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -834,7 +838,10 @@ class OpenSearchIndexClient(OpenSearchClient):
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
self,
body: dict[str, Any],
search_pipeline_id: str | None,
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
"""Searches the index.
@@ -852,6 +859,8 @@ class OpenSearchIndexClient(OpenSearchClient):
documentation for more information on search request bodies.
search_pipeline_id: The ID of the search pipeline to use. If None,
the default search pipeline will be used.
search_type: Label for Prometheus metrics. Does not affect search
behavior.
Raises:
Exception: There was an error searching the index.
@@ -864,21 +873,27 @@ class OpenSearchIndexClient(OpenSearchClient):
)
result: dict[str, Any]
params = {"phase_took": "true"}
if search_pipeline_id:
result = self._client.search(
index=self._index_name,
search_pipeline=search_pipeline_id,
body=body,
params=params,
)
else:
result = self._client.search(
index=self._index_name, body=body, params=params
)
ctx = self._get_emit_metrics_context_manager(search_type)
t0 = time.perf_counter()
with ctx:
if search_pipeline_id:
result = self._client.search(
index=self._index_name,
search_pipeline=search_pipeline_id,
body=body,
params=params,
)
else:
result = self._client.search(
index=self._index_name, body=body, params=params
)
client_duration_s = time.perf_counter() - t0
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
if self._emit_metrics:
observe_opensearch_search(search_type, client_duration_s, time_took)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
@@ -914,7 +929,11 @@ class OpenSearchIndexClient(OpenSearchClient):
return search_hits
@log_function_time(print_only=True, debug_only=True)
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
) -> list[str]:
"""Searches the index and returns only document chunk IDs.
In order to take advantage of the performance benefits of only returning
@@ -931,6 +950,8 @@ class OpenSearchIndexClient(OpenSearchClient):
documentation for more information on search request bodies.
TODO(andrei): Make this a more deep interface; callers shouldn't
need to know to set _source: False for example.
search_type: Label for Prometheus metrics. Does not affect search
behavior.
Raises:
Exception: There was an error searching the index.
@@ -948,13 +969,19 @@ class OpenSearchIndexClient(OpenSearchClient):
)
params = {"phase_took": "true"}
result: dict[str, Any] = self._client.search(
index=self._index_name, body=body, params=params
)
ctx = self._get_emit_metrics_context_manager(search_type)
t0 = time.perf_counter()
with ctx:
result: dict[str, Any] = self._client.search(
index=self._index_name, body=body, params=params
)
client_duration_s = time.perf_counter() - t0
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
if self._emit_metrics:
observe_opensearch_search(search_type, client_duration_s, time_took)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
@@ -1071,6 +1098,20 @@ class OpenSearchIndexClient(OpenSearchClient):
if raise_on_timeout:
raise RuntimeError(error_str)
def _get_emit_metrics_context_manager(
    self, search_type: OpenSearchSearchType
) -> AbstractContextManager[None]:
    """
    Returns a context manager for tracking in-flight OpenSearch searches.

    When metric emission is enabled, the returned context manager updates an
    in-progress Gauge labelled by search_type; otherwise a no-op null context
    manager is returned so callers can use it unconditionally.
    """
    if not self._emit_metrics:
        return nullcontext()
    return track_opensearch_search_in_progress(search_type)
def wait_for_opensearch_with_timeout(
wait_interval_s: int = 5,

View File

@@ -53,6 +53,18 @@ DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
class OpenSearchSearchType(str, Enum):
    """Search type label used for Prometheus metrics."""

    # Hybrid (keyword + semantic) retrieval.
    HYBRID = "hybrid"
    # Keyword-only retrieval.
    KEYWORD = "keyword"
    # Vector/semantic-only retrieval.
    SEMANTIC = "semantic"
    # Random-sample retrieval.
    RANDOM = "random"
    # Retrieval of chunks by explicit ID.
    ID_RETRIEVAL = "id_retrieval"
    # Searches that return only document chunk IDs.
    DOCUMENT_IDS = "document_ids"
    # Fallback label when the caller did not specify a search type.
    UNKNOWN = "unknown"
class HybridSearchSubqueryConfiguration(Enum):
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
# Current default.

View File

@@ -43,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
@@ -900,6 +901,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
@@ -923,6 +925,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -948,6 +952,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=normalization_pipeline_name,
search_type=OpenSearchSearchType.HYBRID,
)
# Good place for a breakpoint to inspect the search hits if you have
@@ -970,6 +975,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -989,6 +996,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.KEYWORD,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
@@ -1009,6 +1017,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int,
) -> list[InferenceChunk]:
# TODO(andrei): There is some duplicated logic in this function with
# others in this file.
logger.debug(
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
)
@@ -1028,6 +1038,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.SEMANTIC,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
@@ -1059,6 +1070,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.RANDOM,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -3,6 +3,8 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import TypeAlias
from typing import TypeVar
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
@@ -48,13 +50,21 @@ from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
# of the weights should sum to 1.
# See https://docs.opensearch.org/latest/query-dsl/term/terms/.
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536
_T = TypeVar("_T")
TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]]
TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]]
# TODO(andrei): Turn all magic dictionaries to pydantic models.
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
# of the weights should sum to 1.
def _get_hybrid_search_normalization_weights() -> list[float]:
if (
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
@@ -316,6 +326,9 @@ class DocumentQuery:
it MUST be supplied in addition to a search pipeline. The results from
hybrid search are not meaningful without that step.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_text: The text to query for.
query_vector: The vector embedding of the text to query for.
@@ -419,6 +432,9 @@ class DocumentQuery:
This query can be directly supplied to the OpenSearch client.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_text: The text to query for.
num_hits: The final number of hits to return.
@@ -498,6 +514,9 @@ class DocumentQuery:
This query can be directly supplied to the OpenSearch client.
TODO(andrei): There is some duplicated logic in this function with
others in this file.
Args:
query_embedding: The vector embedding of the text to query for.
num_hits: The final number of hits to return.
@@ -763,8 +782,9 @@ class DocumentQuery:
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
# The title fields are strongly discounted as
# they are included in the content. This just
# acts as a minor boost.
"boost": 0.1,
}
}
@@ -779,6 +799,9 @@ class DocumentQuery:
}
},
{
# Analyzes the query and returns results which match any
# of the query's terms. More matches result in higher
# scores.
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
@@ -788,18 +811,21 @@ class DocumentQuery:
}
},
{
# Matches an exact phrase in a specified order.
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# The number of words permitted between words of
# a query phrase and still result in a match.
"slop": 1,
"boost": 1.5,
}
}
},
],
# Ensure at least one term from the query is present in the
# document. This defaults to 1, unless a filter or must clause
# is supplied, in which case it defaults to 0.
# Ensures at least one match subquery from the query is present
# in the document. This defaults to 1, unless a filter or must
# clause is supplied, in which case it defaults to 0.
"minimum_should_match": 1,
}
}
@@ -833,7 +859,14 @@ class DocumentQuery:
The "filter" key applies a logical AND operator to its elements, so
every subfilter must evaluate to true in order for the document to be
retrieved. This function returns a list of such subfilters.
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
See https://docs.opensearch.org/latest/query-dsl/compound/bool/.
TODO(ENG-3874): The terms queries returned by this function can be made
more performant for large cardinality sets by sorting the values by
their UTF-8 byte order.
TODO(ENG-3875): This function can take even better advantage of filter
caching by grouping "static" filters together into one sub-clause.
Args:
tenant_state: Tenant state containing the tenant ID.
@@ -878,6 +911,14 @@ class DocumentQuery:
the assistant. Matches chunks where ancestor_hierarchy_node_ids
contains any of these values.
Raises:
ValueError: document_id and attached_document_ids were supplied
together. This is not allowed because they operate on the same
schema field, and it does not semantically make sense to use
them together.
ValueError: Too many of one of the collection arguments was
supplied.
Returns:
A list of filters to be passed into the "filter" key of a search
query.
@@ -885,61 +926,156 @@ class DocumentQuery:
def _get_acl_visibility_filter(
access_control_list: list[str],
) -> dict[str, Any]:
) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]:
"""Returns a filter for the access control list.
Since this returns an isolated bool should clause, it can be cached
in OpenSearch independently of other clauses in _get_search_filters.
Args:
access_control_list: The access control list to restrict
documents to.
Raises:
ValueError: The number of access control list entries is greater
than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
Returns:
A filter for the access control list.
"""
# Logical OR operator on its elements.
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
acl_visibility_filter["bool"]["should"].append(
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
)
for acl in access_control_list:
acl_subclause: dict[str, Any] = {
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
acl_visibility_filter: dict[str, dict[str, Any]] = {
"bool": {
"should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}],
"minimum_should_match": 1,
}
}
if access_control_list:
if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause
# because Lucene will optimize the filtering for large sets of
# terms. Small sets of terms are not expected to perform any
# differently than individual term clauses.
acl_subclause: TermsQuery[str] = {
"terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)}
}
acl_visibility_filter["bool"]["should"].append(acl_subclause)
return acl_visibility_filter
def _get_source_type_filter(
source_types: list[DocumentSource],
) -> dict[str, Any]:
# Logical OR operator on its elements.
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
for source_type in source_types:
source_type_filter["bool"]["should"].append(
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
) -> TermsQuery[str]:
"""Returns a filter for the source types.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
source_types: The source types to restrict documents to.
Raises:
ValueError: The number of source types is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the source types.
"""
if not source_types:
raise ValueError(
"source_types cannot be empty if trying to create a source type filter."
)
return source_type_filter
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
# Logical OR operator on its elements.
tag_filter: dict[str, Any] = {"bool": {"should": []}}
for tag in tags:
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
tag_filter["bool"]["should"].append(
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
return tag_filter
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {
"terms": {
SOURCE_TYPE_FIELD_NAME: [
source_type.value for source_type in source_types
]
}
}
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
# Logical OR operator on its elements.
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
for document_set in document_sets:
document_set_filter["bool"]["should"].append(
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]:
"""Returns a filter for the tags.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
tags: The tags to restrict documents to.
Raises:
ValueError: The number of tags is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the tags.
"""
if not tags:
raise ValueError(
"tags cannot be empty if trying to create a tag filter."
)
return document_set_filter
if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Kind of an abstraction leak, see
# convert_metadata_dict_to_list_of_strings for why metadata list
# entries are expected to look this way.
tag_str_list = [
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags
]
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}}
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
user_project_filter["bool"]["should"].append(
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
)
return user_project_filter
def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]:
"""Returns a filter for the document sets.
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
document_sets: The document sets to restrict documents to.
Raises:
ValueError: The number of document sets is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the document sets.
"""
if not document_sets:
raise ValueError(
"document_sets cannot be empty if trying to create a document set filter."
)
if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}}
def _get_user_project_filter(project_id: int) -> TermQuery[int]:
    """Returns a term filter restricting documents to the given user project."""
    return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
def _get_persona_filter(persona_id: int) -> TermQuery[int]:
    """Returns a term filter restricting documents to the given persona."""
    return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
@@ -947,7 +1083,9 @@ class DocumentQuery:
# document data.
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
# Logical OR operator on its elements.
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
time_cutoff_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
time_cutoff_filter["bool"]["should"].append(
{
"range": {
@@ -982,25 +1120,77 @@ class DocumentQuery:
def _get_attached_document_id_filter(
doc_ids: list[str],
) -> dict[str, Any]:
"""Filter for documents explicitly attached to an assistant."""
# Logical OR operator on its elements.
doc_id_filter: dict[str, Any] = {"bool": {"should": []}}
for doc_id in doc_ids:
doc_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": doc_id}}}
) -> TermsQuery[str]:
"""
Returns a filter for documents explicitly attached to an assistant.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
doc_ids: The document IDs to restrict documents to.
Raises:
ValueError: The number of document IDs is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the document IDs.
"""
if not doc_ids:
raise ValueError(
"doc_ids cannot be empty if trying to create a document ID filter."
)
return doc_id_filter
if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}}
def _get_hierarchy_node_filter(
node_ids: list[int],
) -> dict[str, Any]:
"""Filter for chunks whose ancestors include any of the given hierarchy nodes.
Uses a terms query to check if ancestor_hierarchy_node_ids contains
any of the specified node IDs.
) -> TermsQuery[int]:
"""
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
Returns a filter for chunks whose ancestors include any of the given
hierarchy nodes.
Since this returns an isolated terms clause, it can be cached in
OpenSearch independently of other clauses in _get_search_filters.
Args:
node_ids: The hierarchy node IDs to restrict documents to.
Raises:
ValueError: The number of hierarchy node IDs is greater than
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
ValueError: An empty list was supplied.
Returns:
A filter for the hierarchy node IDs.
"""
if not node_ids:
raise ValueError(
"node_ids cannot be empty if trying to create a hierarchy node ID filter."
)
if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
raise ValueError(
f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
)
# Use terms instead of a list of term within a should clause because
# Lucene will optimize the filtering for large sets of terms. Small
# sets of terms are not expected to perform any differently than
# individual term clauses.
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}}
if document_id is not None and attached_document_ids is not None:
raise ValueError(
"document_id and attached_document_ids cannot be used together."
)
filter_clauses: list[dict[str, Any]] = []
@@ -1045,6 +1235,9 @@ class DocumentQuery:
)
if has_knowledge_scope:
# Since this returns an isolated bool should clause, it can be
# cached in OpenSearch independently of other clauses in
# _get_search_filters.
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}

View File

@@ -610,6 +610,22 @@ class VespaDocumentIndex(DocumentIndex):
return cleanup_content_for_chunks(query_vespa(params))
def keyword_retrieval(
    self,
    query: str,
    filters: IndexFilters,
    num_to_retrieve: int,
) -> list[InferenceChunk]:
    """Keyword retrieval is not implemented for the Vespa document index."""
    raise NotImplementedError
def semantic_retrieval(
    self,
    query_embedding: Embedding,
    filters: IndexFilters,
    num_to_retrieve: int,
) -> list[InferenceChunk]:
    """Semantic retrieval is not implemented for the Vespa document index."""
    raise NotImplementedError
def random_retrieval(
self,
filters: IndexFilters,

View File

@@ -15,7 +15,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (spec.response_model)
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy

View File

@@ -91,6 +91,8 @@ class HookResponse(BaseModel):
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
api_key_masked: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool

View File

@@ -26,6 +26,8 @@ class DocumentIngestionSpec(HookPointSpec):
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -65,6 +65,8 @@ class QueryProcessingSpec(HookPointSpec):
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
display_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -690,9 +690,9 @@
}
},
"node_modules/@dotenvx/dotenvx/node_modules/picomatch": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"license": "MIT",
"engines": {
"node": ">=12"
@@ -9537,9 +9537,9 @@
"license": "ISC"
},
"node_modules/picomatch": {
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"license": "MIT",
"engines": {
"node": ">=8.6"
@@ -11118,9 +11118,9 @@
}
},
"node_modules/tinyglobby/node_modules/picomatch": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"dev": true,
"license": "MIT",
"engines": {

View File

@@ -62,6 +62,9 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
api_key_masked=(
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
),
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
@@ -220,8 +223,8 @@ def create_hook(
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
the endpoint cannot be reached or the api_key is invalid. Hooks are created active.
"""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
@@ -240,9 +243,10 @@ def create_hook(
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
is_active=True,
is_reachable=True,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)

View File

@@ -0,0 +1,207 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
Usage in a worker app module:
from onyx.server.metrics.celery_task_metrics import (
on_celery_task_prerun,
on_celery_task_postrun,
on_celery_task_retry,
on_celery_task_revoked,
on_celery_task_rejected,
)
# Call from the worker's existing signal handlers
"""
import threading
import time
from celery import Task
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Prometheus metric objects for generic Celery task lifecycle tracking.
# Counters/histograms are labelled by task name and (where available) queue.
TASK_STARTED = Counter(
    "onyx_celery_task_started_total",
    "Total Celery tasks started",
    ["task_name", "queue"],
)
# outcome is "success" or "failure" (any non-SUCCESS terminal state).
TASK_COMPLETED = Counter(
    "onyx_celery_task_completed_total",
    "Total Celery tasks completed",
    ["task_name", "queue", "outcome"],
)
TASK_DURATION = Histogram(
    "onyx_celery_task_duration_seconds",
    "Celery task execution duration in seconds",
    ["task_name", "queue"],
    # Buckets from 1s up to 1h to cover both quick tasks and long indexing runs.
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
TASKS_ACTIVE = Gauge(
    "onyx_celery_tasks_active",
    "Currently executing Celery tasks",
    ["task_name", "queue"],
)
TASK_RETRIED = Counter(
    "onyx_celery_task_retried_total",
    "Total Celery tasks retried",
    ["task_name", "queue"],
)
# The revoked/rejected signals do not carry a Task instance, so these two
# metrics only have a task_name label.
TASK_REVOKED = Counter(
    "onyx_celery_task_revoked_total",
    "Total Celery tasks revoked (cancelled)",
    ["task_name"],
)
TASK_REJECTED = Counter(
    "onyx_celery_task_rejected_total",
    "Total Celery tasks rejected by worker",
    ["task_name"],
)
# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
# Lock protecting _task_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_task_start_times_lock = threading.Lock()
# Entries older than this are evicted on each prerun to prevent unbounded
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
_MAX_START_TIME_AGE_SECONDS = 3600  # 1 hour
def _evict_stale_start_times() -> None:
    """Drop _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.

    Must be called while holding _task_start_times_lock.
    """
    cutoff = time.monotonic() - _MAX_START_TIME_AGE_SECONDS
    expired_ids = [
        task_id
        for task_id, (started_at, _) in _task_start_times.items()
        if started_at < cutoff
    ]
    for task_id in expired_ids:
        removed = _task_start_times.pop(task_id, None)
        if removed is None:
            continue
        _, labels = removed
        # These tasks started but never reported completion (killed, OOM,
        # etc.), so bring the active gauge back down — without going negative.
        active_gauge = TASKS_ACTIVE.labels(**labels)
        if active_gauge._value.get() > 0:
            active_gauge.dec()
def _get_task_labels(task: Task) -> dict[str, str]:
    """Build the task_name/queue Prometheus label dict for a Celery task.

    Either label falls back to "unknown" when the information is
    unavailable (unnamed task, or no delivery_info on the request).
    """
    labels = {"task_name": task.name or "unknown", "queue": "unknown"}
    try:
        delivery_info = task.request.delivery_info
        if delivery_info:
            labels["queue"] = delivery_info.get("routing_key") or "unknown"
    except AttributeError:
        # Some execution contexts have no request/delivery_info at all.
        pass
    return labels
def on_celery_task_prerun(
    task_id: str | None,
    task: Task | None,
) -> None:
    """Record task start. Call from the worker's task_prerun signal handler."""
    if task_id is None or task is None:
        return
    try:
        labels = _get_task_labels(task)
        for metric in (TASK_STARTED, TASKS_ACTIVE):
            metric.labels(**labels).inc()
        with _task_start_times_lock:
            # Piggyback stale-entry cleanup on prerun so the start-time map
            # cannot grow without bound when postrun never fires.
            _evict_stale_start_times()
            _task_start_times[task_id] = (time.monotonic(), labels)
    except Exception:
        logger.debug("Failed to record celery task prerun metrics", exc_info=True)
def on_celery_task_postrun(
    task_id: str | None,
    task: Task | None,
    state: str | None,
) -> None:
    """Record task completion. Call from the worker's task_postrun signal handler.

    Increments the completion counter, decrements the active-task gauge, and,
    when a matching prerun entry exists, observes the task duration.
    """
    if task is None or task_id is None:
        return
    try:
        labels = _get_task_labels(task)
        outcome = "success" if state == "SUCCESS" else "failure"
        TASK_COMPLETED.labels(**labels, outcome=outcome).inc()

        with _task_start_times_lock:
            entry = _task_start_times.pop(task_id, None)

        # Prefer the labels captured at prerun so the gauge decrement and the
        # duration histogram hit the same label set that was incremented at
        # prerun. Re-deriving labels here can disagree (e.g. delivery_info may
        # no longer be available, collapsing queue to "unknown"), which would
        # leak TASKS_ACTIVE counts. _evict_stale_start_times already uses the
        # stored labels for the same reason.
        start_time: float | None = None
        gauge_labels = labels
        if entry is not None:
            start_time, gauge_labels = entry

        # Guard against going below 0 if postrun fires without a matching
        # prerun (e.g. after a worker restart or stale entry eviction).
        active_gauge = TASKS_ACTIVE.labels(**gauge_labels)
        if active_gauge._value.get() > 0:
            active_gauge.dec()

        if start_time is not None:
            TASK_DURATION.labels(**gauge_labels).observe(
                time.monotonic() - start_time
            )
    except Exception:
        logger.debug("Failed to record celery task postrun metrics", exc_info=True)
def on_celery_task_retry(
    _task_id: str | None,
    task: Task | None,
) -> None:
    """Record task retry. Call from the worker's task_retry signal handler."""
    if task is None:
        return
    try:
        TASK_RETRIED.labels(**_get_task_labels(task)).inc()
    except Exception:
        logger.debug("Failed to record celery task retry metrics", exc_info=True)
def on_celery_task_revoked(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Record task revocation.

    The revoked signal does not provide a Task instance — only the task
    name (via sender) — hence the name-only signature and metric label.
    """
    if task_name is not None:
        try:
            TASK_REVOKED.labels(task_name=task_name).inc()
        except Exception:
            logger.debug(
                "Failed to record celery task revoked metrics", exc_info=True
            )
def on_celery_task_rejected(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Record task rejection by the worker (name-only, like revocation)."""
    if task_name is not None:
        try:
            TASK_REJECTED.labels(task_name=task_name).inc()
        except Exception:
            logger.debug(
                "Failed to record celery task rejected metrics", exc_info=True
            )

View File

@@ -0,0 +1,528 @@
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
These collectors query Redis and Postgres at scrape time (the Collector pattern),
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
monitoring celery worker which already has Redis and DB access.
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
stale, which is fine for monitoring dashboards.
"""
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import OnyxCeleryQueues
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default cache TTL in seconds. Scrapes hitting within this window return
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0

# Maps Celery queue names (the Redis list keys) to the short, stable values
# exported in the `queue` metric label.
_QUEUE_LABEL_MAP: dict[str, str] = {
    OnyxCeleryQueues.PRIMARY: "primary",
    OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
    OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
    OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
    OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
    OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
    OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
    OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
    OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
    OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
    OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
    OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
    OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
    OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
    OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
    OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
    OnyxCeleryQueues.MONITORING: "monitoring",
    OnyxCeleryQueues.SANDBOX: "sandbox",
    OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
}

# Queues where prefetched (unacked) task counts are meaningful
_UNACKED_QUEUES: list[str] = [
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
    OnyxCeleryQueues.DOCPROCESSING,
]
class _CachedCollector(Collector):
    """Base collector with TTL-based caching.

    Subclasses implement ``_collect_fresh()`` to query the actual data source.
    The base ``collect()`` returns cached results if the TTL hasn't expired,
    avoiding repeated queries when Prometheus scrapes frequently.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        self._cache_ttl = cache_ttl
        self._cached_result: list[GaugeMetricFamily] | None = None
        self._last_collect_time: float = 0.0
        self._lock = threading.Lock()

    def collect(self) -> list[GaugeMetricFamily]:
        """Return cached metrics while fresh, otherwise re-query the source."""
        with self._lock:
            now = time.monotonic()
            if (
                self._cached_result is not None
                and now - self._last_collect_time < self._cache_ttl
            ):
                return self._cached_result
            try:
                fresh = self._collect_fresh()
            except Exception:
                logger.exception(f"Error in {type(self).__name__}.collect()")
                # Serve the stale cache on failure so dashboards don't go
                # blank during transient Redis/DB outages.
                return [] if self._cached_result is None else self._cached_result
            self._cached_result = fresh
            self._last_collect_time = now
            return fresh

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        # Subclass responsibility: query the data source and build gauges.
        raise NotImplementedError

    def describe(self) -> list[GaugeMetricFamily]:
        # Empty describe() prevents a full collect() at registration time.
        return []
class QueueDepthCollector(_CachedCollector):
    """Reads Celery queue lengths from the broker Redis on each scrape.

    Uses a Redis client factory (callable) rather than a stored client
    reference so the connection is always fresh from Celery's pool.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Injected later via set_redis_factory(); until then _collect_fresh()
        # returns no metrics so early scrapes don't fail.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Build depth, unacked, and oldest-task-age gauges for all queues."""
        if self._get_redis is None:
            return []
        redis_client = self._get_redis()
        depth = GaugeMetricFamily(
            "onyx_queue_depth",
            "Number of tasks waiting in Celery queue",
            labels=["queue"],
        )
        unacked = GaugeMetricFamily(
            "onyx_queue_unacked",
            "Number of prefetched (unacked) tasks for queue",
            labels=["queue"],
        )
        queue_age = GaugeMetricFamily(
            "onyx_queue_oldest_task_age_seconds",
            "Age of the oldest task in the queue (seconds since enqueue)",
            labels=["queue"],
        )
        # Wall-clock time, since message timestamps are wall-clock too.
        now = time.time()
        for queue_name, label in _QUEUE_LABEL_MAP.items():
            length = celery_get_queue_length(queue_name, redis_client)
            depth.add_metric([label], length)
            # Peek at the oldest message to get its age
            if length > 0:
                age = self._get_oldest_message_age(redis_client, queue_name, now)
                if age is not None:
                    queue_age.add_metric([label], age)
        # Unacked counts are only gathered for the queues where prefetch
        # actually matters (see _UNACKED_QUEUES).
        for queue_name in _UNACKED_QUEUES:
            label = _QUEUE_LABEL_MAP[queue_name]
            task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
            unacked.add_metric([label], len(task_ids))
        return [depth, unacked, queue_age]

    @staticmethod
    def _get_oldest_message_age(
        redis_client: Redis, queue_name: str, now: float
    ) -> float | None:
        """Peek at the oldest (tail) message in a Redis list queue
        and extract its timestamp to compute age.

        Note: If the Celery message contains neither ``properties.timestamp``
        nor ``headers.timestamp``, no age metric is emitted for this queue.
        This can happen with custom task producers or non-standard Celery
        protocol versions. The metric will simply be absent rather than
        inaccurate, which is the safest behavior for alerting.
        """
        try:
            # LINDEX -1 reads the tail, which is the oldest entry for a
            # LPUSH/BRPOP-style Celery list queue, without removing it.
            raw: bytes | str | None = redis_client.lindex(queue_name, -1)  # type: ignore[assignment]
            if raw is None:
                return None
            msg = json.loads(raw)
            # Check for ETA tasks first — they are intentionally delayed,
            # so reporting their queue age would be misleading.
            headers = msg.get("headers", {})
            if headers.get("eta") is not None:
                return None
            # Celery v2 protocol: timestamp in properties
            props = msg.get("properties", {})
            ts = props.get("timestamp")
            if ts is not None:
                return now - float(ts)
            # Fallback: some Celery configurations place the timestamp in
            # headers instead of properties.
            ts = headers.get("timestamp")
            if ts is not None:
                return now - float(ts)
        except Exception:
            # Best-effort peek: a malformed/undecodable payload simply
            # results in no age metric for this queue.
            pass
        return None
class IndexAttemptCollector(_CachedCollector):
    """Queries Postgres for index attempt state on each scrape.

    ``configure()`` must be called once the DB engine is initialized;
    before that, ``_collect_fresh()`` emits nothing so that scrapes during
    worker startup don't fail.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Flipped to True by configure(); gates _collect_fresh().
        # (Previously this also precomputed a _terminal_statuses list that
        # was never read — dead state, now removed. The active/terminal
        # filtering is done by get_active_index_attempts_for_metrics.)
        self._configured: bool = False

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Count non-terminal index attempts per tenant/connector/status."""
        if not self._configured:
            return []
        # Imported lazily so merely constructing the collector doesn't
        # require the DB layer to be importable.
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_active_index_attempts_for_metrics
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        attempts_gauge = GaugeMetricFamily(
            "onyx_index_attempts_active",
            "Number of non-terminal index attempts",
            labels=[
                "status",
                "source",
                "tenant_id",
                "connector_name",
                "cc_pair_id",
            ],
        )
        tenant_ids = get_all_tenant_ids()
        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    rows = get_active_index_attempts_for_metrics(session)
                    for status, source, cc_id, cc_name, count in rows:
                        # Fall back to a synthetic name when the pair is unnamed.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        attempts_gauge.add_metric(
                            [
                                status.value,
                                source.value,
                                tid,
                                name_val,
                                str(cc_id),
                            ],
                            count,
                        )
            finally:
                # Always restore the previous tenant context, even on error.
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
        return [attempts_gauge]
class ConnectorHealthCollector(_CachedCollector):
    """Queries Postgres for connector health state on each scrape."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Flipped to True by configure(); gates _collect_fresh() so scrapes
        # before DB initialization emit nothing rather than erroring.
        self._configured: bool = False

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Build per-connector health gauges plus per-tenant rollups."""
        if not self._configured:
            return []
        # Lazy imports keep the DB layer out of collector construction.
        from onyx.db.connector_credential_pair import (
            get_connector_health_for_metrics,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
        from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        staleness_gauge = GaugeMetricFamily(
            "onyx_connector_last_success_age_seconds",
            "Seconds since last successful index for this connector",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        error_state_gauge = GaugeMetricFamily(
            "onyx_connector_in_error_state",
            "Whether the connector is in a repeated error state (1=yes, 0=no)",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        by_status_gauge = GaugeMetricFamily(
            "onyx_connectors_by_status",
            "Number of connectors grouped by status",
            labels=["tenant_id", "status"],
        )
        error_total_gauge = GaugeMetricFamily(
            "onyx_connectors_in_error_total",
            "Total number of connectors in repeated error state",
            labels=["tenant_id"],
        )
        # Shared label schema for the two per-connector counters below.
        per_connector_labels = [
            "tenant_id",
            "source",
            "cc_pair_id",
            "connector_name",
        ]
        docs_success_gauge = GaugeMetricFamily(
            "onyx_connector_docs_indexed",
            "Total new documents indexed (90-day rolling sum) per connector",
            labels=per_connector_labels,
        )
        docs_error_gauge = GaugeMetricFamily(
            "onyx_connector_error_count",
            "Total number of failed index attempts per connector",
            labels=per_connector_labels,
        )
        now = datetime.now(tz=timezone.utc)
        tenant_ids = get_all_tenant_ids()
        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    pairs = get_connector_health_for_metrics(session)
                    error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
                    docs_by_cc = get_docs_indexed_by_cc_pair(session)
                    # Per-tenant rollups accumulated while iterating pairs.
                    status_counts: dict[str, int] = {}
                    error_count = 0
                    for (
                        cc_id,
                        status,
                        in_error,
                        last_success,
                        cc_name,
                        source,
                    ) in pairs:
                        cc_id_str = str(cc_id)
                        source_val = source.value
                        # Synthetic fallback name when the pair is unnamed.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        label_vals = [tid, source_val, cc_id_str, name_val]
                        if last_success is not None:
                            # Both `now` and `last_success` are timezone-aware
                            # (the DB column uses DateTime(timezone=True)),
                            # so subtraction is safe.
                            age = (now - last_success).total_seconds()
                            staleness_gauge.add_metric(label_vals, age)
                        error_state_gauge.add_metric(
                            label_vals,
                            1.0 if in_error else 0.0,
                        )
                        if in_error:
                            error_count += 1
                        # Missing keys mean "no attempts recorded" → 0.
                        docs_success_gauge.add_metric(
                            label_vals,
                            docs_by_cc.get(cc_id, 0),
                        )
                        docs_error_gauge.add_metric(
                            label_vals,
                            error_counts_by_cc.get(cc_id, 0),
                        )
                        status_val = status.value
                        status_counts[status_val] = status_counts.get(status_val, 0) + 1
                    for status_val, count in status_counts.items():
                        by_status_gauge.add_metric([tid, status_val], count)
                    error_total_gauge.add_metric([tid], error_count)
            finally:
                # Always restore the previous tenant context, even on error.
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
        return [
            staleness_gauge,
            error_state_gauge,
            by_status_gauge,
            error_total_gauge,
            docs_success_gauge,
            docs_error_gauge,
        ]
class RedisHealthCollector(_CachedCollector):
    """Collects Redis server health metrics (memory, clients, etc.)."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Injected via set_redis_factory(); collection is a no-op until then.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Read the INFO memory/clients sections and expose them as gauges."""
        if self._get_redis is None:
            return []
        client = self._get_redis()
        used = GaugeMetricFamily(
            "onyx_redis_memory_used_bytes",
            "Redis used memory in bytes",
        )
        peak = GaugeMetricFamily(
            "onyx_redis_memory_peak_bytes",
            "Redis peak used memory in bytes",
        )
        frag_ratio = GaugeMetricFamily(
            "onyx_redis_memory_fragmentation_ratio",
            "Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
        )
        clients = GaugeMetricFamily(
            "onyx_redis_connected_clients",
            "Number of connected Redis clients",
        )
        try:
            memory_section: dict = client.info("memory")  # type: ignore[assignment]
            used.add_metric([], memory_section.get("used_memory", 0))
            peak.add_metric([], memory_section.get("used_memory_peak", 0))
            # The fragmentation ratio is only present on some Redis builds;
            # emit the sample only when the server reports it.
            ratio = memory_section.get("mem_fragmentation_ratio")
            if ratio is not None:
                frag_ratio.add_metric([], ratio)
            clients_section: dict = client.info("clients")  # type: ignore[assignment]
            clients.add_metric([], clients_section.get("connected_clients", 0))
        except Exception:
            logger.debug("Failed to collect Redis health metrics", exc_info=True)
        return [used, peak, frag_ratio, clients]
class WorkerHealthCollector(_CachedCollector):
    """Collects Celery worker count and process count via inspect ping.

    Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
    command that takes a couple seconds to complete.

    Maintains a set of known worker short-names so that when a worker
    stops responding, we emit ``up=0`` instead of silently dropping the
    metric (which would make ``absent()``-style alerts impossible).
    """

    # Remove a worker from _known_workers after this many consecutive
    # missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
    _MAX_CONSECUTIVE_MISSES = 10

    def __init__(self, cache_ttl: float = 60.0) -> None:
        super().__init__(cache_ttl)
        # Injected via set_celery_app(); collection is a no-op until then.
        self._celery_app: Any | None = None
        # worker short-name → consecutive miss count.
        # Workers start at 0 and reset to 0 each time they respond.
        # Removed after _MAX_CONSECUTIVE_MISSES missed collects.
        self._known_workers: dict[str, int] = {}

    def set_celery_app(self, app: Any) -> None:
        """Set the Celery app instance for inspect commands."""
        self._celery_app = app

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Ping all workers and emit count plus per-worker up/down gauges."""
        if self._celery_app is None:
            return []
        active_workers = GaugeMetricFamily(
            "onyx_celery_active_worker_count",
            "Number of active Celery workers responding to ping",
        )
        worker_up = GaugeMetricFamily(
            "onyx_celery_worker_up",
            "Whether a specific Celery worker is alive (1=up, 0=down)",
            labels=["worker"],
        )
        try:
            # Broadcast ping; workers that don't answer within 3s count as
            # missing for this collect cycle.
            inspector = self._celery_app.control.inspect(timeout=3.0)
            ping_result = inspector.ping()
            responding: set[str] = set()
            if ping_result:
                active_workers.add_metric([], len(ping_result))
                for worker_name in ping_result:
                    # Strip hostname suffix for cleaner labels
                    short_name = worker_name.split("@")[0]
                    responding.add(short_name)
            else:
                active_workers.add_metric([], 0)
            # Register newly-seen workers and reset miss count for
            # workers that responded.
            for short_name in responding:
                self._known_workers[short_name] = 0
            # Increment miss count for non-responding workers and evict
            # those that have been missing too long.
            stale = []
            for short_name in list(self._known_workers):
                if short_name not in responding:
                    self._known_workers[short_name] += 1
                    if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
                        stale.append(short_name)
            for short_name in stale:
                del self._known_workers[short_name]
            # Emit up=1/0 for every worker still tracked, in stable order.
            for short_name in sorted(self._known_workers):
                worker_up.add_metric([short_name], 1 if short_name in responding else 0)
        except Exception:
            logger.debug("Failed to collect worker health metrics", exc_info=True)
        return [active_workers, worker_up]

View File

@@ -0,0 +1,113 @@
"""Setup function for indexing pipeline Prometheus collectors.
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
# Registration itself happens in setup_indexing_pipeline_metrics().
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
    """Create a factory that returns a cached broker Redis client.

    Reuses a single connection across scrapes to avoid leaking connections.
    Reconnects automatically if the cached connection becomes stale.
    """
    # Closure state: the cached raw Redis client, plus the Kombu Connection
    # that produced it (kept so it can be closed on reconnect — the raw
    # client outlives the Kombu wrapper).
    cached_client: Redis | None = None
    cached_conn: Any = None

    def _discard_stale(client: Redis) -> None:
        """Best-effort close of the stale client and its Kombu connection."""
        nonlocal cached_client, cached_conn
        try:
            client.close()
        except Exception:
            logger.debug("Failed to close stale Redis client", exc_info=True)
        cached_client = None
        if cached_conn is not None:
            try:
                cached_conn.close()
            except Exception:
                logger.debug("Failed to close Kombu connection", exc_info=True)
            cached_conn = None

    def _get_broker_redis() -> Redis:
        nonlocal cached_client, cached_conn
        if cached_client is not None:
            try:
                # Cheap liveness probe; reuse the connection when healthy.
                cached_client.ping()
                return cached_client
            except Exception:
                logger.debug("Cached Redis client stale, reconnecting")
                _discard_stale(cached_client)
        # Get a fresh Redis client from the broker connection. We hold this
        # client long-term (cached above) rather than using a context
        # manager, because it must persist across scrapes. The caching logic
        # ensures only one connection is ever held, closed on reconnect.
        conn = celery_app.broker_connection()
        # kombu's Channel exposes .client at runtime (the underlying Redis
        # client) but the type stubs don't declare it.
        client: Redis = conn.channel().client  # type: ignore[attr-defined]
        cached_client = client
        cached_conn = conn
        return client

    return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
    """Register all indexing pipeline collectors with the default registry.

    Args:
        celery_app: The Celery application instance. Used to obtain a fresh
            broker Redis client on each scrape for queue depth metrics.
    """
    broker_redis_factory = _make_broker_redis_factory(celery_app)

    # Wire up data sources before registering, so the first scrape after
    # registration already produces metrics.
    _queue_collector.set_redis_factory(broker_redis_factory)
    _redis_health_collector.set_redis_factory(broker_redis_factory)
    _worker_health_collector.set_celery_app(celery_app)
    _attempt_collector.configure()
    _connector_collector.configure()

    all_collectors = [
        _queue_collector,
        _attempt_collector,
        _connector_collector,
        _redis_health_collector,
        _worker_health_collector,
    ]
    for collector in all_collectors:
        try:
            REGISTRY.register(collector)
        except ValueError:
            # ValueError means the collector is already registered (e.g. a
            # repeated worker init) — benign, so just note it.
            logger.debug("Collector already registered: %s", type(collector).__name__)

View File

@@ -0,0 +1,253 @@
"""Per-connector Prometheus metrics for indexing tasks.
Enriches the two primary indexing tasks (docfetching_proxy_task and
docprocessing_task) with connector-level labels: source, tenant_id,
and cc_pair_id.
Note: connector_name is intentionally excluded from push-based per-task
counters because it is a user-defined free-form string that can create
unbounded cardinality. The pull-based collectors on the monitoring worker
(see indexing_pipeline.py) include connector_name since they have bounded
cardinality (one series per connector, not per task execution).
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
Connectors never change source type, and names change rarely, so the
cache is safe to hold for the worker's lifetime.
Usage in a worker app module:
from onyx.server.metrics.indexing_task_metrics import (
on_indexing_task_prerun,
on_indexing_task_postrun,
)
"""
import threading
import time
from dataclasses import dataclass
from celery import Task
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.configs.constants import OnyxCeleryTask
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@dataclass(frozen=True)
class ConnectorInfo:
    """Cached connector metadata for metric labels."""

    # Connector source type used as the `source` metric label.
    source: str
    # Connector-credential-pair display name (used by pull-based collectors).
    name: str


# Sentinel used when connector resolution fails; never cached so later
# calls can retry the lookup (see _resolve_connector).
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")

# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
# deployments where different tenants can share the same cc_pair_id value.
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
# Lock protecting _connector_cache — multiple thread-pool workers may
# resolve connectors concurrently.
_connector_cache_lock = threading.Lock()

# Only enrich these task types with per-connector labels
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
    {
        OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
        OnyxCeleryTask.DOCPROCESSING_TASK,
    }
)

# connector_name is intentionally excluded — see module docstring.
INDEXING_TASK_STARTED = Counter(
    "onyx_indexing_task_started_total",
    "Indexing tasks started per connector",
    ["task_name", "source", "tenant_id", "cc_pair_id"],
)
INDEXING_TASK_COMPLETED = Counter(
    "onyx_indexing_task_completed_total",
    "Indexing tasks completed per connector",
    [
        "task_name",
        "source",
        "tenant_id",
        "cc_pair_id",
        "outcome",
    ],
)
INDEXING_TASK_DURATION = Histogram(
    "onyx_indexing_task_duration_seconds",
    "Indexing task duration by connector type",
    # Note: cc_pair_id is omitted here, keeping histogram cardinality lower
    # than the counters above.
    ["task_name", "source", "tenant_id"],
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)

# task_id → monotonic start time (for indexing tasks only)
_indexing_start_times: dict[str, float] = {}
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_indexing_start_times_lock = threading.Lock()
def _evict_stale_start_times() -> None:
    """Remove _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.

    Must be called while holding _indexing_start_times_lock.
    """
    # Anything that started before this cutoff is considered abandoned
    # (e.g. the postrun signal never fired) and is dropped.
    cutoff = time.monotonic() - _MAX_START_TIME_AGE_SECONDS
    expired = [task_id for task_id, started in _indexing_start_times.items() if started < cutoff]
    for task_id in expired:
        _indexing_start_times.pop(task_id, None)
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
    """Resolve cc_pair_id to ConnectorInfo, using cache when possible.

    On cache miss, does a single DB query with eager connector load.
    On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
    subsequent calls can retry the lookup once the DB is available.

    Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
    cache key. The Celery tenant-aware middleware sets this contextvar before
    task execution, and it always matches kwargs["tenant_id"] (which is set
    at task dispatch time). They are guaranteed to agree for a given task
    execution context.
    """
    # Empty string is the cache key for "no tenant context set".
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
    cache_key = (tenant_id, cc_pair_id)
    with _connector_cache_lock:
        cached = _connector_cache.get(cache_key)
        if cached is not None:
            return cached
    try:
        # Imported lazily so this module can load before the DB layer is ready.
        from onyx.db.connector_credential_pair import (
            get_connector_credential_pair_from_id,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session,
                cc_pair_id,
                eager_load_connector=True,
            )
            if cc_pair is None:
                # DB lookup succeeded but cc_pair doesn't exist — don't cache,
                # it may appear later (race with connector creation).
                return _UNKNOWN_CONNECTOR
            info = ConnectorInfo(
                source=cc_pair.connector.source.value,
                name=cc_pair.name,
            )
        # Cache only successful resolutions; the lock is re-acquired here
        # (it was released during the DB query to avoid blocking other tasks).
        with _connector_cache_lock:
            _connector_cache[cache_key] = info
        return info
    except Exception:
        logger.debug(
            f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
            exc_info=True,
        )
        return _UNKNOWN_CONNECTOR
def on_indexing_task_prerun(
    task_id: str | None,
    task: Task | None,
    kwargs: dict | None,
) -> None:
    """Record per-connector metrics at task start.

    Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
    all other tasks.
    """
    if task is None or task_id is None or kwargs is None:
        return
    task_name = task.name or ""
    if task_name not in _INDEXING_TASK_NAMES:
        return
    try:
        cc_pair_id = kwargs.get("cc_pair_id")
        if cc_pair_id is None:
            return
        tenant_id = str(kwargs.get("tenant_id", "unknown"))
        connector = _resolve_connector(cc_pair_id)
        INDEXING_TASK_STARTED.labels(
            task_name=task_name,
            source=connector.source,
            tenant_id=tenant_id,
            cc_pair_id=str(cc_pair_id),
        ).inc()
        # Opportunistically evict abandoned entries while we hold the lock,
        # then record this task's start time for the duration histogram.
        with _indexing_start_times_lock:
            _evict_stale_start_times()
            _indexing_start_times[task_id] = time.monotonic()
    except Exception:
        logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
def on_indexing_task_postrun(
    task_id: str | None,
    task: Task | None,
    kwargs: dict | None,
    state: str | None,
) -> None:
    """Record per-connector completion metrics.

    Only fires for tasks in _INDEXING_TASK_NAMES.

    Args:
        task_id: Celery task id, used to look up the recorded start time.
        task: The Celery task; its name gates and labels the metrics.
        kwargs: The task's keyword arguments (reads cc_pair_id / tenant_id).
        state: Celery result state; "SUCCESS" → outcome="success",
            anything else → outcome="failure".
    """
    if task is None or task_id is None or kwargs is None:
        return
    task_name = task.name or ""
    if task_name not in _INDEXING_TASK_NAMES:
        return
    try:
        cc_pair_id = kwargs.get("cc_pair_id")
        tenant_id = str(kwargs.get("tenant_id", "unknown"))
        if cc_pair_id is None:
            return
        info = _resolve_connector(cc_pair_id)
        outcome = "success" if state == "SUCCESS" else "failure"
        INDEXING_TASK_COMPLETED.labels(
            task_name=task_name,
            source=info.source,
            tenant_id=tenant_id,
            cc_pair_id=str(cc_pair_id),
            outcome=outcome,
        ).inc()
        # Pop under the lock, but observe outside it: the histogram update
        # doesn't touch the start-times dict (and prometheus metrics are
        # internally thread-safe), so keep the critical section limited to
        # the dict mutation.
        with _indexing_start_times_lock:
            start = _indexing_start_times.pop(task_id, None)
        if start is not None:
            INDEXING_TASK_DURATION.labels(
                task_name=task_name,
                source=info.source,
                tenant_id=tenant_id,
            ).observe(time.monotonic() - start)
    except Exception:
        logger.debug("Failed to record indexing task postrun metrics", exc_info=True)

View File

@@ -0,0 +1,89 @@
"""Standalone Prometheus metrics HTTP server for non-API processes.
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
Celery workers and other background processes use this module to expose their
own /metrics endpoint on a configurable port.
Usage:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring") # reads port from env or uses default
"""
import os
import threading
from prometheus_client import start_http_server
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default ports for worker types that serve custom Prometheus metrics.
# Only add entries here when a worker actually registers collectors.
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
# env var can override.
_DEFAULT_PORTS: dict[str, int] = {
    "monitoring": 9096,
    "docfetching": 9092,
    "docprocessing": 9093,
}

# Process-wide guard: the HTTP server is started at most once per process.
_server_started = False
# Serializes concurrent start attempts against the flag above.
_server_lock = threading.Lock()
def start_metrics_server(worker_type: str) -> int | None:
    """Start a Prometheus metrics HTTP server in a background thread.

    Returns the port if started, None if disabled or already started.

    Port resolution order:
    1. PROMETHEUS_METRICS_PORT env var (explicit override)
    2. Default port for the worker type
    3. If worker type is unknown and no env var, skip

    Set PROMETHEUS_METRICS_ENABLED=false to disable.
    """
    global _server_started
    with _server_lock:
        # Idempotent: later calls in the same process are no-ops.
        if _server_started:
            logger.debug(f"Metrics server already started for {worker_type}")
            return None
        if os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower() in (
            "false",
            "0",
            "no",
        ):
            logger.info(f"Prometheus metrics server disabled for {worker_type}")
            return None
        # Resolve the port: explicit env override first, then per-worker default.
        port_override = os.environ.get("PROMETHEUS_METRICS_PORT")
        if port_override:
            try:
                port = int(port_override)
            except ValueError:
                logger.warning(
                    f"Invalid PROMETHEUS_METRICS_PORT '{port_override}' for {worker_type}, "
                    "must be a numeric port. Skipping metrics server."
                )
                return None
        else:
            default_port = _DEFAULT_PORTS.get(worker_type)
            if default_port is None:
                logger.info(
                    f"No default metrics port for worker type '{worker_type}' "
                    "and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
                )
                return None
            port = default_port
        try:
            start_http_server(port)
        except OSError as e:
            # Typically the port is already bound by another process.
            logger.warning(
                f"Failed to start metrics server on :{port} for {worker_type}: {e}"
            )
            return None
        _server_started = True
        logger.info(
            f"Prometheus metrics server started on :{port} for {worker_type}"
        )
        return port

View File

@@ -0,0 +1,106 @@
"""Prometheus metrics for OpenSearch search latency and throughput.
Tracks client-side round-trip latency, server-side execution time (from
OpenSearch's ``took`` field), total search count, and in-flight concurrency.
"""
import logging
from collections.abc import Generator
from contextlib import contextmanager
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from onyx.document_index.opensearch.constants import OpenSearchSearchType
logger = logging.getLogger(__name__)
# Shared histogram buckets (seconds) for both client- and server-side latency,
# so the two distributions are directly comparable.
_SEARCH_LATENCY_BUCKETS = (
    0.005,
    0.01,
    0.025,
    0.05,
    0.1,
    0.25,
    0.5,
    1.0,
    2.5,
    5.0,
    10.0,
    25.0,
)

_client_duration = Histogram(
    "onyx_opensearch_search_client_duration_seconds",
    "Client-side end-to-end latency of OpenSearch search calls",
    ["search_type"],
    buckets=_SEARCH_LATENCY_BUCKETS,
)
_server_duration = Histogram(
    "onyx_opensearch_search_server_duration_seconds",
    "Server-side execution time reported by OpenSearch (took field)",
    ["search_type"],
    buckets=_SEARCH_LATENCY_BUCKETS,
)
_search_total = Counter(
    "onyx_opensearch_search_total",
    "Total number of search requests sent to OpenSearch",
    ["search_type"],
)
_searches_in_progress = Gauge(
    "onyx_opensearch_searches_in_progress",
    "Number of OpenSearch searches currently in-flight",
    ["search_type"],
)
def observe_opensearch_search(
    search_type: OpenSearchSearchType,
    client_duration_s: float,
    server_took_ms: int | None,
) -> None:
    """Records latency and throughput metrics for a completed OpenSearch search.

    Args:
        search_type: The type of search.
        client_duration_s: Wall-clock duration measured on the client side, in
            seconds.
        server_took_ms: The ``took`` value from the OpenSearch response, in
            milliseconds. May be ``None`` if the response did not include it.
    """
    try:
        search_label = search_type.value
        _search_total.labels(search_type=search_label).inc()
        _client_duration.labels(search_type=search_label).observe(client_duration_s)
        if server_took_ms is None:
            # Response omitted `took`; only client-side latency is recorded.
            return
        # Convert milliseconds to seconds to match the shared buckets.
        _server_duration.labels(search_type=search_label).observe(server_took_ms / 1000.0)
    except Exception:
        logger.warning("Failed to record OpenSearch search metrics.", exc_info=True)
@contextmanager
def track_opensearch_search_in_progress(
    search_type: OpenSearchSearchType,
) -> Generator[None, None, None]:
    """Context manager that tracks in-flight OpenSearch searches via a Gauge."""
    search_label = search_type.value
    did_increment = False
    try:
        _searches_in_progress.labels(search_type=search_label).inc()
        did_increment = True
    except Exception:
        logger.warning("Failed to increment in-progress search gauge.", exc_info=True)
    try:
        yield
    finally:
        # Only undo an increment that actually happened, so a failed inc()
        # can never drive the gauge negative.
        if did_increment:
            try:
                _searches_in_progress.labels(search_type=search_label).dec()
            except Exception:
                logger.warning(
                    "Failed to decrement in-progress search gauge.", exc_info=True
                )

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import run_multi_model_stream
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in run_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -41,6 +41,21 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class ModelResponseSlot(BaseModel):
    """Pairs a reserved assistant message ID with its model display name."""

    # Reserved DB ID of the assistant message that this model will fill in.
    message_id: int
    # Display name of the model assigned to this slot.
    model_name: str
class MultiModelMessageResponseIDInfo(BaseModel):
    """Sent at the start of a multi-model streaming response.

    Contains the user message ID and one slot per model being run in parallel.
    """

    # None is permitted — presumably when no user message row was created;
    # TODO(review): confirm against the producer of this packet.
    user_message_id: int | None
    # One slot per parallel model (reserved assistant message ID + model name).
    responses: list[ModelResponseSlot]
class SourceTag(Tag):
source: DocumentSource
@@ -86,6 +101,9 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -211,6 +229,8 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
preferred_response_id: int | None = None
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -218,6 +238,11 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SetPreferredResponseRequest(BaseModel):
    """Request body for marking one multi-model response as preferred."""

    # ID of the user message whose multi-model responses are being chosen from.
    user_message_id: int
    # ID of the assistant message to mark as the preferred response.
    preferred_response_id: int
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None

View File

@@ -8,3 +8,5 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -17,6 +17,7 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -80,6 +81,7 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
)

View File

@@ -104,5 +104,7 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -736,7 +736,7 @@ if __name__ == "__main__":
llm.config.model_name, llm.config.model_provider
)
persona = get_default_behavior_persona(db_session)
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
if persona is None:
raise ValueError("No default persona found")

View File

@@ -9,6 +9,7 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -124,7 +125,12 @@ def construct_tools(
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs.
Will simply skip tools that are not allowed/available."""
Will simply skip tools that are not allowed/available.
Callers must supply a persona with ``tools``, ``document_sets``,
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging
@@ -143,6 +149,28 @@ def construct_tools(
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
persona_search_info = PersonaSearchInfo(
document_set_names=[ds.name for ds in persona.document_sets],
search_start_date=persona.search_start_date,
attached_document_ids=[doc.id for doc in persona.attached_documents],
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
)
return SearchTool(
tool_id=tool_id,
emitter=emitter,
user=user,
persona_search_info=persona_search_info,
llm=llm,
document_index=document_index,
user_selected_filters=config.user_selected_filters,
project_id_filter=config.project_id_filter,
persona_id_filter=config.persona_id_filter,
bypass_acl=config.bypass_acl,
slack_context=config.slack_context,
enable_slack_search=config.enable_slack_search,
)
added_search_tool = False
for db_tool_model in persona.tools:
# If allowed_tool_ids is specified, skip tools not in the allowed list
@@ -176,22 +204,9 @@ def construct_tools(
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
tool_id=db_tool_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[db_tool_model.id] = [search_tool]
tool_dict[db_tool_model.id] = [
_build_search_tool(db_tool_model.id, search_tool_config)
]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -421,26 +436,12 @@ def construct_tools(
# Get the database tool model for SearchTool
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
# Use the passed-in config if available, otherwise create a new one
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
tool_id=search_tool_db_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[search_tool_db_model.id] = [search_tool]
tool_dict[search_tool_db_model.id] = [
_build_search_tool(search_tool_db_model.id, search_tool_config)
]
# Always inject MemoryTool when the user has the memory tool enabled,
# bypassing persona tool associations and allowed_tool_ids filtering

View File

@@ -51,6 +51,7 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDocsResponse
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
@@ -65,7 +66,6 @@ from onyx.db.federated import (
get_federated_connector_document_set_mappings_by_document_set_names,
)
from onyx.db.federated import list_federated_connector_oauth_tokens
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
emitter: Emitter,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Used for filter settings
persona: Persona,
# Pre-extracted persona search configuration
persona_search_info: PersonaSearchInfo,
llm: LLM,
document_index: DocumentIndex,
# Respecting user selections
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
super().__init__(emitter=emitter)
self.user = user
self.persona = persona
self.persona_search_info = persona_search_info
self.llm = llm
self.document_index = document_index
self.user_selected_filters = user_selected_filters
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Case 1: Slack bot context — requires a Slack federated connector
# linked via the persona's document sets
if self.slack_context:
document_set_names = [ds.name for ds in self.persona.document_sets]
document_set_names = self.persona_search_info.document_set_names
if not document_set_names:
logger.debug(
"Skipping Slack federated search: no document sets on persona"
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
persona_id_filter=self.persona_id_filter,
document_index=self.document_index,
user=self.user,
persona=self.persona,
persona_search_info=self.persona_search_info,
acl_filters=acl_filters,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=federated_retrieval_infos,
@@ -587,15 +587,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
and self.user_selected_filters.source_type
else None
)
persona_document_sets = (
[ds.name for ds in self.persona.document_sets] if self.persona else None
)
federated_retrieval_infos = (
get_federated_retrieval_functions(
db_session=db_session,
user_id=self.user.id if self.user else None,
source_types=prefetch_source_types,
document_set_names=persona_document_sets,
document_set_names=self.persona_search_info.document_set_names,
)
or []
)

View File

@@ -549,7 +549,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.3
nltk==3.9.4
# via unstructured
numpy==2.4.1
# via
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.9.1
pypdf==6.9.2
# via
# onyx
# unstructured-client
@@ -861,7 +861,7 @@ regex==2025.11.3
# dateparser
# nltk
# tiktoken
requests==2.32.5
requests==2.33.0
# via
# atlassian-python-api
# braintrust

View File

@@ -410,7 +410,7 @@ release-tag==0.5.2
# via onyx
reorder-python-imports-black==3.14.0
# via onyx
requests==2.32.5
requests==2.33.0
# via
# cohere
# google-genai

View File

@@ -244,7 +244,7 @@ referencing==0.36.2
# jsonschema-specifications
regex==2025.11.3
# via tiktoken
requests==2.32.5
requests==2.33.0
# via
# cohere
# google-genai

View File

@@ -338,7 +338,7 @@ regex==2025.11.3
# via
# tiktoken
# transformers
requests==2.32.5
requests==2.33.0
# via
# cohere
# google-genai

View File

@@ -43,5 +43,8 @@ def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
persona_unloaded = tmp.unloaded
assert "tools" not in persona_unloaded
assert "user_files" not in persona_unloaded
assert "document_sets" not in persona_unloaded
assert "attached_documents" not in persona_unloaded
assert "hierarchy_nodes" not in persona_unloaded
finally:
db_session.rollback()

View File

@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDoc
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
@@ -139,7 +139,7 @@ def use_mock_search_pipeline(
chunk_search_request: ChunkSearchRequest,
document_index: DocumentIndex, # noqa: ARG001
user: User | None, # noqa: ARG001
persona: Persona | None, # noqa: ARG001
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
db_session: Session | None = None, # noqa: ARG001
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001

View File

@@ -10,6 +10,9 @@ from typing import Any
from onyx.configs.constants import DocumentSource
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -17,6 +20,12 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
HIERARCHY_NODE_ID = 42
PERSONA_ID = 7
KNOWLEDGE_FILTER_SCHEMA_FIELDS = {
DOCUMENT_ID_FIELD_NAME,
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME,
DOCUMENT_SETS_FIELD_NAME,
PERSONAS_FIELD_NAME,
}
def _get_search_filters(
@@ -62,7 +71,26 @@ class TestAssistantKnowledgeFilter:
knowledge_filter = None
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
if (
clause["bool"].get("minimum_should_match") == 1
and len(clause["bool"]["should"]) > 0
and (
(
clause["bool"]["should"][0].get("term", {}).keys()
and list(
clause["bool"]["should"][0].get("term", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
or (
clause["bool"]["should"][0].get("terms", {}).keys()
and list(
clause["bool"]["should"][0].get("terms", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
)
):
knowledge_filter = clause
break
@@ -96,7 +124,26 @@ class TestAssistantKnowledgeFilter:
knowledge_filter = None
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
if (
clause["bool"].get("minimum_should_match") == 1
and len(clause["bool"]["should"]) > 0
and (
(
clause["bool"]["should"][0].get("term", {}).keys()
and list(
clause["bool"]["should"][0].get("term", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
or (
clause["bool"]["should"][0].get("terms", {}).keys()
and list(
clause["bool"]["should"][0].get("terms", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
)
):
knowledge_filter = clause
break
@@ -127,7 +174,26 @@ class TestAssistantKnowledgeFilter:
knowledge_filter = None
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
if (
clause["bool"].get("minimum_should_match") == 1
and len(clause["bool"]["should"]) > 0
and (
(
clause["bool"]["should"][0].get("term", {}).keys()
and list(
clause["bool"]["should"][0].get("term", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
or (
clause["bool"]["should"][0].get("terms", {}).keys()
and list(
clause["bool"]["should"][0].get("terms", {}).keys()
)[0]
in KNOWLEDGE_FILTER_SCHEMA_FIELDS
)
)
):
knowledge_filter = clause
break

View File

@@ -974,7 +974,7 @@ class TestOpenSearchClient:
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
# Index documents with different public/hidden, ACL, and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
@@ -997,7 +997,7 @@ class TestOpenSearchClient:
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_emails=["user-a@example.com", "user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
@@ -1044,7 +1044,10 @@ class TestOpenSearchClient:
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
access_control_list=[
prefix_user_email("user-a@example.com"),
prefix_user_email("user-c@example.com"),
],
tenant_id=None,
),
include_hidden=False,
@@ -1661,7 +1664,7 @@ class TestOpenSearchClient:
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
# Index documents with different public/hidden, ACL, and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
@@ -1684,7 +1687,7 @@ class TestOpenSearchClient:
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_emails=["user-a@example.com", "user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
@@ -1746,7 +1749,10 @@ class TestOpenSearchClient:
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
access_control_list=[
prefix_user_email("user-a@example.com"),
prefix_user_email("user-c@example.com"),
],
tenant_id=None,
),
include_hidden=False,
@@ -1805,7 +1811,7 @@ class TestOpenSearchClient:
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden and tenant states.
# Index documents with different public/hidden, ACL, and tenant states.
docs = {
"public-doc": _create_test_document_chunk(
document_id="public-doc",
@@ -1831,7 +1837,7 @@ class TestOpenSearchClient:
hidden=False,
tenant_state=tenant_x,
document_access=DocumentAccess.build(
user_emails=["user-a@example.com"],
user_emails=["user-a@example.com", "user-b@example.com"],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
@@ -1879,7 +1885,10 @@ class TestOpenSearchClient:
# The user should only be able to see their private docs. tenant_id
# in this object is not relevant.
index_filters=IndexFilters(
access_control_list=[prefix_user_email("user-a@example.com")],
access_control_list=[
prefix_user_email("user-a@example.com"),
prefix_user_email("user-c@example.com"),
],
tenant_id=None,
),
include_hidden=False,

View File

@@ -80,6 +80,7 @@ def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
source=DocumentSource.SHAREPOINT,
connector_specific_config={
"sites": sharepoint_sites.split(","),
"treat_sharing_link_as_public": True,
},
access_type=AccessType.SYNC, # Enable permission sync
user_performing_action=admin_user,

View File

@@ -8,6 +8,9 @@ import pytest
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_enumerate_ad_groups_paginated,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_is_public_item,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_iter_graph_collection,
)
@@ -334,3 +337,143 @@ def test_site_page_url_not_duplicated(
ctx.web.get_file_by_server_relative_url.assert_called_once_with(
expected_relative_url
)
# ---------------------------------------------------------------------------
# _is_public_item sharing link visibility
# ---------------------------------------------------------------------------
def _make_permission(scope: str | None) -> MagicMock:
perm = MagicMock()
if scope is None:
perm.link = None
else:
perm.link = MagicMock()
perm.link.scope = scope
return perm
def _make_drive_item_with_permissions(
permissions: list[MagicMock],
) -> MagicMock:
drive_item = MagicMock()
drive_item.id = "item-123"
drive_item.permissions.get_all.return_value = permissions
return drive_item
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_anonymous_link_when_enabled(
_mock_sleep: MagicMock,
) -> None:
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is True
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_org_link_when_enabled(
_mock_sleep: MagicMock,
) -> None:
drive_item = _make_drive_item_with_permissions([_make_permission("organization")])
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is True
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_anonymous_link_when_disabled(
_mock_sleep: MagicMock,
) -> None:
"""When the flag is off, anonymous links do NOT make the item public."""
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
assert _is_public_item(drive_item, treat_sharing_link_as_public=False) is False
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_org_link_when_disabled(
_mock_sleep: MagicMock,
) -> None:
"""When the flag is off, org links do NOT make the item public."""
drive_item = _make_drive_item_with_permissions([_make_permission("organization")])
assert _is_public_item(drive_item, treat_sharing_link_as_public=False) is False
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_no_sharing_links(
_mock_sleep: MagicMock,
) -> None:
"""User-level permissions only — not public even when flag is on."""
drive_item = _make_drive_item_with_permissions([_make_permission(None)])
assert _is_public_item(drive_item, treat_sharing_link_as_public=True) is False
@patch(f"{MODULE}.sleep_and_retry", side_effect=lambda query, _label: query)
def test_is_public_item_default_is_false(
_mock_sleep: MagicMock,
) -> None:
"""Default value of the flag is False, so sharing links are ignored."""
drive_item = _make_drive_item_with_permissions([_make_permission("anonymous")])
assert _is_public_item(drive_item) is False
def test_is_public_item_skips_api_call_when_disabled() -> None:
    """When the flag is off, the permissions API is never called."""
    drive_item = MagicMock()
    _is_public_item(drive_item, treat_sharing_link_as_public=False)
    # Short-circuit: the Graph permissions endpoint must not be hit at all.
    drive_item.permissions.get_all.assert_not_called()
# ---------------------------------------------------------------------------
# get_external_access_from_sharepoint sharing link integration
# ---------------------------------------------------------------------------
@patch(f"{MODULE}._is_public_item", return_value=True)
@patch(f"{MODULE}.sleep_and_retry")
def test_drive_item_public_when_sharing_link_enabled(
_mock_sleep: MagicMock,
_mock_is_public: MagicMock,
) -> None:
"""With treat_sharing_link_as_public=True, a public item returns is_public=True
and skips role-assignment resolution entirely."""
drive_item = MagicMock()
result = get_external_access_from_sharepoint(
client_context=MagicMock(),
graph_client=MagicMock(),
drive_name="Documents",
drive_item=drive_item,
site_page=None,
treat_sharing_link_as_public=True,
)
assert result.is_public is True
assert result.external_user_emails == set()
assert result.external_user_group_ids == set()
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
@patch(f"{MODULE}._is_public_item", return_value=False)
def test_drive_item_falls_through_when_sharing_link_disabled(
_mock_is_public: MagicMock,
mock_sleep: MagicMock, # noqa: ARG001
mock_recursive: MagicMock,
) -> None:
"""With treat_sharing_link_as_public=False, the function falls through to
role-assignment-based permission resolution."""
mock_recursive.return_value = GroupsResult(
groups_to_emails={"SiteMembers_abc": {"alice@contoso.com"}},
found_public_group=False,
)
result = get_external_access_from_sharepoint(
client_context=MagicMock(),
graph_client=MagicMock(),
drive_name="Documents",
drive_item=MagicMock(),
site_page=None,
treat_sharing_link_as_public=False,
)
assert result.is_public is False
assert len(result.external_user_group_ids) > 0

View File

@@ -0,0 +1,207 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in run_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
    """Build a SendMessageRequest with minimal defaults; kwargs override them."""
    payload: dict[str, Any] = {
        "message": "hello",
        "chat_session_id": uuid4(),
        **kwargs,
    }
    return SendMessageRequest(**payload)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
    # Convenience factory — tests mostly just need distinct override objects.
    return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
    """Advance the generator one step to trigger early validation.

    run_multi_model_stream is a generator, so its body (including input
    validation) only executes once the first item is requested.
    """
    # Imported lazily — presumably to avoid module import side effects at
    # collection time; TODO(review): confirm.
    from onyx.chat.process_message import run_multi_model_stream

    user = MagicMock()
    user.is_anonymous = False
    user.email = "test@example.com"
    db = MagicMock()
    gen = run_multi_model_stream(req, user, db, overrides)
    # Calling next() executes until the first yield OR raises.
    # Validation errors are raised before any yield.
    next(gen)
# ---------------------------------------------------------------------------
# run_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
    """Covers the 2-3 override count bounds and deep-research incompatibility.

    All cases trigger validation in run_multi_model_stream before any LLM or
    DB work happens, so mocked user/db objects are sufficient.
    """

    def test_single_override_raises(self) -> None:
        """Exactly 1 override is not multi-model — must raise."""
        req = _make_request()
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(req, [_make_override()])

    def test_four_overrides_raises(self) -> None:
        """4 overrides exceeds maximum — must raise."""
        req = _make_request()
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(
                req,
                [
                    _make_override("openai", "gpt-4"),
                    _make_override("anthropic", "claude-3"),
                    _make_override("google", "gemini-pro"),
                    _make_override("cohere", "command-r"),
                ],
            )

    def test_zero_overrides_raises(self) -> None:
        """Empty override list raises."""
        req = _make_request()
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(req, [])

    def test_deep_research_raises(self) -> None:
        """deep_research=True is incompatible with multi-model."""
        req = _make_request(deep_research=True)
        with pytest.raises(ValueError, match="not supported"):
            _start_stream(
                req, [_make_override(), _make_override("anthropic", "claude-3")]
            )

    def test_exactly_two_overrides_is_minimum(self) -> None:
        """Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
        req = _make_request()
        # 1 override must fail
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(req, [_make_override()])
        # 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
        try:
            _start_stream(
                req, [_make_override(), _make_override("anthropic", "claude-3")]
            )
        except ValueError as exc:
            pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
        except Exception:
            pass  # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
    """set_preferred_response argument validation against a mocked session."""

    def test_user_message_not_found(self) -> None:
        session = MagicMock()
        session.get.return_value = None
        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=999, preferred_assistant_message_id=1
            )

    def test_wrong_message_type(self) -> None:
        """Cannot set preferred response on a non-USER message."""
        session = MagicMock()
        not_a_user_msg = MagicMock()
        not_a_user_msg.message_type = MessageType.ASSISTANT  # wrong type
        session.get.return_value = not_a_user_msg
        with pytest.raises(ValueError, match="not a user message"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_message_not_found(self) -> None:
        session = MagicMock()
        parent = MagicMock()
        parent.message_type = MessageType.USER
        # First get() returns the user message; second (assistant lookup) None.
        session.get.side_effect = [parent, None]
        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_not_child_of_user(self) -> None:
        session = MagicMock()
        parent = MagicMock()
        parent.message_type = MessageType.USER
        child = MagicMock()
        child.parent_message_id = 999  # different parent
        session.get.side_effect = [parent, child]
        with pytest.raises(ValueError, match="not a child"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_valid_call_sets_preferred_response_id(self) -> None:
        session = MagicMock()
        parent = MagicMock()
        parent.message_type = MessageType.USER
        child = MagicMock()
        child.parent_message_id = 1  # correct parent
        session.get.side_effect = [parent, child]
        set_preferred_response(
            session, user_message_id=1, preferred_assistant_message_id=2
        )
        assert parent.preferred_response_id == 2
        assert parent.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
    """display_name is an optional, serializable field on LLMOverride."""

    def test_display_name_defaults_none(self) -> None:
        plain = LLMOverride(model_provider="openai", model_version="gpt-4")
        assert plain.display_name is None

    def test_display_name_set(self) -> None:
        named = LLMOverride(
            model_provider="openai",
            model_version="gpt-4",
            display_name="GPT-4 Turbo",
        )
        assert named.display_name == "GPT-4 Turbo"

    def test_display_name_serializes(self) -> None:
        dumped = LLMOverride(
            model_provider="anthropic",
            model_version="claude-opus-4-6",
            display_name="Claude Opus",
        ).model_dump()
        assert dumped["display_name"] == "Claude Opus"

View File

@@ -0,0 +1,146 @@
"""Unit tests for multi-model answer generation types.
Tests cover:
- Placement.model_index serialization
- MultiModelMessageResponseIDInfo round-trip
- SendMessageRequest.llm_overrides backward compatibility
- ChatMessageDetail new fields
"""
from datetime import datetime
from datetime import timezone
from uuid import uuid4
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import ModelResponseSlot
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
class TestPlacementModelIndex:
    """Placement.model_index: optional, settable, survives model_dump()."""

    def test_default_none(self) -> None:
        assert Placement(turn_index=0).model_index is None

    def test_set_value(self) -> None:
        placement = Placement(turn_index=0, model_index=2)
        assert placement.model_index == 2

    def test_serializes(self) -> None:
        dumped = Placement(turn_index=0, tab_index=1, model_index=1).model_dump()
        assert dumped["model_index"] == 1

    def test_none_excluded_when_default(self) -> None:
        dumped = Placement(turn_index=0).model_dump()
        assert dumped["model_index"] is None
class TestMultiModelMessageResponseIDInfo:
    """Round-trip serialization of the multi-model response ID packet."""

    def test_round_trip(self) -> None:
        slots = [
            ModelResponseSlot(message_id=mid, model_name=name)
            for mid, name in [(43, "gpt-4"), (44, "claude-opus"), (45, "gemini-pro")]
        ]
        info = MultiModelMessageResponseIDInfo(user_message_id=42, responses=slots)
        # Dump to a plain dict and rebuild; nothing may be lost.
        restored = MultiModelMessageResponseIDInfo(**info.model_dump())
        assert restored.user_message_id == 42
        assert [s.message_id for s in restored.responses] == [43, 44, 45]
        assert [s.model_name for s in restored.responses] == [
            "gpt-4",
            "claude-opus",
            "gemini-pro",
        ]

    def test_null_user_message_id(self) -> None:
        info = MultiModelMessageResponseIDInfo(
            user_message_id=None,
            responses=[
                ModelResponseSlot(message_id=1, model_name="a"),
                ModelResponseSlot(message_id=2, model_name="b"),
            ],
        )
        assert info.user_message_id is None
class TestSendMessageRequestOverrides:
    """llm_overrides (plural) coexists with the legacy llm_override field."""

    def test_llm_overrides_default_none(self) -> None:
        req = SendMessageRequest(message="hello", chat_session_id=uuid4())
        assert req.llm_overrides is None

    def test_llm_overrides_accepts_list(self) -> None:
        pair = [
            LLMOverride(model_provider="openai", model_version="gpt-4"),
            LLMOverride(model_provider="anthropic", model_version="claude-opus"),
        ]
        req = SendMessageRequest(
            message="hello",
            chat_session_id=uuid4(),
            llm_overrides=pair,
        )
        assert req.llm_overrides is not None
        assert len(req.llm_overrides) == 2

    def test_backward_compat_single_override(self) -> None:
        # Setting only the legacy singular field must not populate the plural one.
        single = LLMOverride(model_provider="openai", model_version="gpt-4")
        req = SendMessageRequest(
            message="hello",
            chat_session_id=uuid4(),
            llm_override=single,
        )
        assert req.llm_override is not None
        assert req.llm_overrides is None
class TestChatMessageDetailMultiModel:
    """New multi-model fields on ChatMessageDetail default to None and serialize."""

    def test_defaults_none(self) -> None:
        from onyx.configs.constants import MessageType

        detail = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.ASSISTANT,
            time_sent=datetime(2026, 3, 22, tzinfo=timezone.utc),
            files=[],
        )
        assert detail.preferred_response_id is None
        assert detail.model_display_name is None

    def test_set_values(self) -> None:
        from onyx.configs.constants import MessageType

        detail = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.USER,
            time_sent=datetime(2026, 3, 22, tzinfo=timezone.utc),
            files=[],
            preferred_response_id=42,
            model_display_name="GPT-4",
        )
        assert detail.preferred_response_id == 42
        assert detail.model_display_name == "GPT-4"

    def test_serializes(self) -> None:
        from onyx.configs.constants import MessageType

        dumped = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.ASSISTANT,
            time_sent=datetime(2026, 3, 22, tzinfo=timezone.utc),
            files=[],
            model_display_name="Claude Opus",
        ).model_dump()
        assert dumped["model_display_name"] == "Claude Opus"
        assert dumped["preferred_response_id"] is None

View File

@@ -60,4 +60,4 @@ def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
with pytest.raises(HTTPError):
handled_call()
assert mock_confluence_call.call_count == 1
assert mock_confluence_call.call_count == 5

View File

@@ -0,0 +1,321 @@
"""Unit tests for Notion connector data source API migration.
Tests the new data source discovery + querying flow and the
data_source_id -> database_id parent resolution.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from requests.exceptions import HTTPError
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.notion.connector import NotionDataSource
from onyx.connectors.notion.connector import NotionPage
def _make_connector() -> NotionConnector:
    """Return a NotionConnector with a fake integration token loaded."""
    c = NotionConnector()
    c.load_credentials({"notion_integration_token": "fake-token"})
    return c
def _mock_response(json_data: dict, status_code: int = 200) -> MagicMock:
    """Build a requests-style response mock; status >= 400 raises on raise_for_status."""
    response = MagicMock()
    response.status_code = status_code
    response.json.return_value = json_data
    if status_code < 400:
        response.raise_for_status.return_value = None
    else:
        response.raise_for_status.side_effect = HTTPError(
            f"HTTP {status_code}", response=response
        )
    return response
class TestFetchDataSourcesForDatabase:
    """Discovery of data sources attached to a Notion database."""

    def test_multi_source_database(self) -> None:
        connector = _make_connector()
        payload = {
            "object": "database",
            "id": "db-1",
            "data_sources": [
                {"id": "ds-1", "name": "Source A"},
                {"id": "ds-2", "name": "Source B"},
            ],
        }
        with patch(
            "onyx.connectors.notion.connector.rl_requests.get",
            return_value=_mock_response(payload),
        ):
            sources = connector._fetch_data_sources_for_database("db-1")
        assert sources == [
            NotionDataSource(id="ds-1", name="Source A"),
            NotionDataSource(id="ds-2", name="Source B"),
        ]

    def test_single_source_database(self) -> None:
        connector = _make_connector()
        payload = {
            "object": "database",
            "id": "db-1",
            "data_sources": [{"id": "ds-1", "name": "Only Source"}],
        }
        with patch(
            "onyx.connectors.notion.connector.rl_requests.get",
            return_value=_mock_response(payload),
        ):
            sources = connector._fetch_data_sources_for_database("db-1")
        assert sources == [NotionDataSource(id="ds-1", name="Only Source")]

    def test_404_returns_empty(self) -> None:
        # A deleted/inaccessible database yields no sources rather than an error.
        connector = _make_connector()
        with patch(
            "onyx.connectors.notion.connector.rl_requests.get",
            return_value=_mock_response({"object": "error"}, status_code=404),
        ):
            sources = connector._fetch_data_sources_for_database("db-missing")
        assert sources == []
class TestFetchDataSource:
    """Querying a single data source for its pages."""

    def test_query_returns_pages(self) -> None:
        connector = _make_connector()
        payload = {
            "results": [
                {
                    "object": "page",
                    "id": "page-1",
                    "properties": {"Name": {"type": "title", "title": []}},
                }
            ],
            "next_cursor": None,
        }
        with patch(
            "onyx.connectors.notion.connector.rl_requests.post",
            return_value=_mock_response(payload),
        ):
            data = connector._fetch_data_source("ds-1")
        assert len(data["results"]) == 1
        assert data["results"][0]["id"] == "page-1"
        assert data["next_cursor"] is None

    def test_404_returns_empty_results(self) -> None:
        # A missing data source degrades to an empty result page.
        connector = _make_connector()
        with patch(
            "onyx.connectors.notion.connector.rl_requests.post",
            return_value=_mock_response({"object": "error"}, status_code=404),
        ):
            data = connector._fetch_data_source("ds-missing")
        assert data == {"results": [], "next_cursor": None}
class TestGetParentRawId:
    """Resolution of a Notion parent reference to a raw hierarchy node id."""

    def test_database_id_parent(self) -> None:
        connector = _make_connector()
        ref = {"type": "database_id", "database_id": "db-1"}
        assert connector._get_parent_raw_id(ref) == "db-1"

    def test_data_source_id_with_mapping(self) -> None:
        connector = _make_connector()
        connector._data_source_to_database_map["ds-1"] = "db-1"
        ref = {"type": "data_source_id", "data_source_id": "ds-1"}
        # Known data source resolves through the map to its owning database.
        assert connector._get_parent_raw_id(ref) == "db-1"

    def test_data_source_id_without_mapping_falls_back(self) -> None:
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        ref = {"type": "data_source_id", "data_source_id": "ds-unknown"}
        assert connector._get_parent_raw_id(ref) == "ws-1"

    def test_workspace_parent(self) -> None:
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        assert connector._get_parent_raw_id({"type": "workspace"}) == "ws-1"

    def test_page_id_parent(self) -> None:
        connector = _make_connector()
        ref = {"type": "page_id", "page_id": "page-1"}
        assert connector._get_parent_raw_id(ref) == "page-1"

    def test_block_id_parent_with_mapping(self) -> None:
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        connector._child_page_parent_map["inline-page-1"] = "containing-page-1"
        resolved = connector._get_parent_raw_id(
            {"type": "block_id"}, page_id="inline-page-1"
        )
        assert resolved == "containing-page-1"

    def test_block_id_parent_without_mapping_falls_back(self) -> None:
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        resolved = connector._get_parent_raw_id(
            {"type": "block_id"}, page_id="unknown-page"
        )
        assert resolved == "ws-1"

    def test_none_parent_defaults_to_workspace(self) -> None:
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        assert connector._get_parent_raw_id(None) == "ws-1"
class TestReadPagesFromDatabaseMultiSource:
    """_read_pages_from_database fan-out across multiple Notion data sources."""

    def test_queries_all_data_sources(self) -> None:
        """Every discovered data source is queried once and mapped back to its DB."""
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        with (
            patch.object(
                connector,
                "_fetch_data_sources_for_database",
                return_value=[
                    NotionDataSource(id="ds-1", name="Source A"),
                    NotionDataSource(id="ds-2", name="Source B"),
                ],
            ),
            patch.object(
                connector,
                "_fetch_data_source",
                return_value={"results": [], "next_cursor": None},
            ) as mock_fetch_ds,
        ):
            result = connector._read_pages_from_database("db-1")
        # One query per data source, each starting with a None cursor.
        assert mock_fetch_ds.call_count == 2
        mock_fetch_ds.assert_any_call("ds-1", None)
        mock_fetch_ds.assert_any_call("ds-2", None)
        # Both sources must be recorded as belonging to db-1 so that later
        # data_source_id parents can be resolved back to the database.
        assert connector._data_source_to_database_map["ds-1"] == "db-1"
        assert connector._data_source_to_database_map["ds-2"] == "db-1"
        assert result.blocks == []
        assert result.child_page_ids == []
        # The database itself still contributes exactly one hierarchy node.
        assert len(result.hierarchy_nodes) == 1
        assert result.hierarchy_nodes[0].raw_node_id == "db-1"

    def test_collects_pages_from_all_sources(self) -> None:
        """Pages returned by each data source are merged into child_page_ids."""
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        connector.recursive_index_enabled = True
        ds1_results = {
            "results": [{"object": "page", "id": "page-from-ds1", "properties": {}}],
            "next_cursor": None,
        }
        ds2_results = {
            "results": [{"object": "page", "id": "page-from-ds2", "properties": {}}],
            "next_cursor": None,
        }
        with (
            patch.object(
                connector,
                "_fetch_data_sources_for_database",
                return_value=[
                    NotionDataSource(id="ds-1", name="Source A"),
                    NotionDataSource(id="ds-2", name="Source B"),
                ],
            ),
            patch.object(
                connector,
                "_fetch_data_source",
                # side_effect order matters: ds-1 is queried first, then ds-2.
                side_effect=[ds1_results, ds2_results],
            ),
        ):
            result = connector._read_pages_from_database("db-1")
        assert "page-from-ds1" in result.child_page_ids
        assert "page-from-ds2" in result.child_page_ids

    def test_pagination_across_pages(self) -> None:
        """A non-null next_cursor triggers a follow-up query with that cursor."""
        connector = _make_connector()
        connector.workspace_id = "ws-1"
        connector.recursive_index_enabled = True
        page1 = {
            "results": [{"object": "page", "id": "page-1", "properties": {}}],
            "next_cursor": "cursor-abc",
        }
        page2 = {
            "results": [{"object": "page", "id": "page-2", "properties": {}}],
            "next_cursor": None,
        }
        with (
            patch.object(
                connector,
                "_fetch_data_sources_for_database",
                return_value=[NotionDataSource(id="ds-1", name="Source A")],
            ),
            patch.object(
                connector,
                "_fetch_data_source",
                # First call reports a cursor; second call ends pagination.
                side_effect=[page1, page2],
            ) as mock_fetch_ds,
        ):
            result = connector._read_pages_from_database("db-1")
        assert mock_fetch_ds.call_count == 2
        mock_fetch_ds.assert_any_call("ds-1", None)
        mock_fetch_ds.assert_any_call("ds-1", "cursor-abc")
        assert "page-1" in result.child_page_ids
        assert "page-2" in result.child_page_ids
class TestInTrashField:
    """NotionPage models the `in_trash` flag."""

    def _build_page(self, in_trash: bool) -> NotionPage:
        # Shared fixture: only the in_trash flag varies between tests.
        return NotionPage(
            id="page-1",
            created_time="2026-01-01T00:00:00.000Z",
            last_edited_time="2026-01-01T00:00:00.000Z",
            in_trash=in_trash,
            properties={},
            url="https://notion.so/page-1",
        )

    def test_notion_page_accepts_in_trash(self) -> None:
        assert self._build_page(False).in_trash is False

    def test_notion_page_in_trash_true(self) -> None:
        assert self._build_page(True).in_trash is True
class TestFetchDatabaseAsPage:
    """Fetching a database object and adapting it into the NotionPage shape."""

    def test_handles_missing_properties(self) -> None:
        connector = _make_connector()
        payload = {
            "object": "database",
            "id": "db-1",
            "created_time": "2026-01-01T00:00:00.000Z",
            "last_edited_time": "2026-01-01T00:00:00.000Z",
            "in_trash": False,
            "url": "https://notion.so/db-1",
            "title": [{"text": {"content": "My DB"}, "plain_text": "My DB"}],
            "data_sources": [{"id": "ds-1", "name": "Source"}],
        }
        with patch(
            "onyx.connectors.notion.connector.rl_requests.get",
            return_value=_mock_response(payload),
        ):
            page = connector._fetch_database_as_page("db-1")
        # The payload has no "properties" key; the adapter must fill in {}.
        assert page.id == "db-1"
        assert page.database_name == "My DB"
        assert page.properties == {}

View File

@@ -145,6 +145,7 @@ def _mock_convert(monkeypatch: pytest.MonkeyPatch) -> None:
include_permissions: bool = False, # noqa: ARG001
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
access_token: str | None = None, # noqa: ARG001
treat_sharing_link_as_public: bool = False, # noqa: ARG001
) -> Document:
return _make_document(driveitem)

View File

@@ -0,0 +1,215 @@
from __future__ import annotations
import pytest
from onyx.connectors.sharepoint.connector import _build_item_relative_path
from onyx.connectors.sharepoint.connector import _is_path_excluded
from onyx.connectors.sharepoint.connector import _is_site_excluded
from onyx.connectors.sharepoint.connector import DriveItemData
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.sharepoint.connector import SiteDescriptor
class TestIsSiteExcluded:
    """_is_site_excluded glob matching against site URLs."""

    def test_exact_match(self) -> None:
        patterns = ["https://contoso.sharepoint.com/sites/archive"]
        assert _is_site_excluded("https://contoso.sharepoint.com/sites/archive", patterns)

    def test_trailing_slash_mismatch(self) -> None:
        # A trailing slash on the URL must not defeat an exact pattern.
        patterns = ["https://contoso.sharepoint.com/sites/archive"]
        assert _is_site_excluded("https://contoso.sharepoint.com/sites/archive/", patterns)

    def test_glob_wildcard(self) -> None:
        assert _is_site_excluded(
            "https://contoso.sharepoint.com/sites/archive-2024", ["*/sites/archive-*"]
        )

    def test_no_match(self) -> None:
        excluded = _is_site_excluded(
            "https://contoso.sharepoint.com/sites/engineering",
            ["https://contoso.sharepoint.com/sites/archive"],
        )
        assert not excluded

    def test_empty_patterns(self) -> None:
        assert not _is_site_excluded(
            "https://contoso.sharepoint.com/sites/engineering", []
        )

    def test_multiple_patterns(self) -> None:
        patterns = ["*/sites/archive-*", "*/sites/hr-confidential"]
        assert _is_site_excluded(
            "https://contoso.sharepoint.com/sites/hr-confidential", patterns
        )
        assert not _is_site_excluded(
            "https://contoso.sharepoint.com/sites/engineering", patterns
        )
class TestIsPathExcluded:
    """_is_path_excluded glob matching on drive-relative paths."""

    def test_filename_glob(self) -> None:
        assert _is_path_excluded("Engineering/report.tmp", ["*.tmp"])

    def test_filename_only(self) -> None:
        assert _is_path_excluded("report.tmp", ["*.tmp"])

    def test_office_lock_files(self) -> None:
        # Office creates ~$-prefixed lock files alongside open documents.
        assert _is_path_excluded("Docs/~$document.docx", ["~$*"])

    def test_folder_glob(self) -> None:
        assert _is_path_excluded("Archive/old/report.docx", ["Archive/*"])

    def test_nested_folder_glob(self) -> None:
        assert _is_path_excluded("Projects/Archive/report.docx", ["*/Archive/*"])

    def test_no_match(self) -> None:
        assert not _is_path_excluded("Engineering/report.docx", ["*.tmp"])

    def test_empty_patterns(self) -> None:
        assert not _is_path_excluded("anything.docx", [])

    def test_multiple_patterns(self) -> None:
        patterns = ["*.tmp", "~$*", "Archive/*"]
        for hit in ("test.tmp", "~$doc.docx", "Archive/old.pdf"):
            assert _is_path_excluded(hit, patterns)
        assert not _is_path_excluded("Engineering/report.docx", patterns)
class TestBuildItemRelativePath:
    """Joining a Graph parentReference path with an item name."""

    def test_with_folder(self) -> None:
        path = _build_item_relative_path(
            "/drives/abc/root:/Engineering/API", "report.docx"
        )
        assert path == "Engineering/API/report.docx"

    def test_root_level(self) -> None:
        path = _build_item_relative_path("/drives/abc/root:", "report.docx")
        assert path == "report.docx"

    def test_none_parent(self) -> None:
        assert _build_item_relative_path(None, "report.docx") == "report.docx"

    def test_percent_encoded_folder(self) -> None:
        # Graph percent-encodes folder names in parentReference paths.
        path = _build_item_relative_path(
            "/drives/abc/root:/My%20Documents", "report.docx"
        )
        assert path == "My Documents/report.docx"

    def test_no_root_marker(self) -> None:
        assert _build_item_relative_path("/drives/abc", "report.docx") == "report.docx"
class TestFilterExcludedSites:
    """SharepointConnector._filter_excluded_sites drops matching descriptors."""

    @staticmethod
    def _descriptor(url: str) -> SiteDescriptor:
        # Minimal descriptor: only the URL participates in exclusion matching.
        return SiteDescriptor(url=url, drive_name=None, folder_path=None)

    def test_filters_matching_sites(self) -> None:
        connector = SharepointConnector(
            excluded_sites=["*/sites/archive"],
        )
        descriptors = [
            self._descriptor("https://t.sharepoint.com/sites/archive"),
            self._descriptor("https://t.sharepoint.com/sites/engineering"),
        ]
        kept = connector._filter_excluded_sites(descriptors)
        assert len(kept) == 1
        assert kept[0].url == "https://t.sharepoint.com/sites/engineering"

    def test_empty_excluded_returns_all(self) -> None:
        connector = SharepointConnector(excluded_sites=[])
        descriptors = [
            self._descriptor("https://t.sharepoint.com/sites/a"),
            self._descriptor("https://t.sharepoint.com/sites/b"),
        ]
        kept = connector._filter_excluded_sites(descriptors)
        assert len(kept) == 2
class TestIsDriveitemExcluded:
    """Path-based drive item exclusion, including pattern normalization."""

    @staticmethod
    def _item(name: str, parent: str = "/drives/abc/root:/Docs") -> DriveItemData:
        return DriveItemData(
            id="1",
            name=name,
            web_url=f"https://example.com/{name}",
            parent_reference_path=parent,
        )

    def test_excluded_by_extension(self) -> None:
        connector = SharepointConnector(excluded_paths=["*.tmp"])
        assert connector._is_driveitem_excluded(self._item("file.tmp"))

    def test_not_excluded(self) -> None:
        connector = SharepointConnector(excluded_paths=["*.tmp"])
        assert not connector._is_driveitem_excluded(self._item("file.docx"))

    def test_no_patterns_never_excludes(self) -> None:
        connector = SharepointConnector(excluded_paths=[])
        assert not connector._is_driveitem_excluded(self._item("file.tmp"))

    def test_folder_pattern(self) -> None:
        connector = SharepointConnector(excluded_paths=["Archive/*"])
        archived = self._item("old.pdf", parent="/drives/abc/root:/Archive")
        assert connector._is_driveitem_excluded(archived)

    @pytest.mark.parametrize(
        "whitespace_pattern",
        ["", " ", "\t"],
    )
    def test_whitespace_patterns_ignored(self, whitespace_pattern: str) -> None:
        # Blank/whitespace-only patterns are dropped at construction time.
        connector = SharepointConnector(excluded_paths=[whitespace_pattern])
        assert connector.excluded_paths == []

    def test_whitespace_padded_patterns_are_trimmed(self) -> None:
        connector = SharepointConnector(excluded_paths=[" *.tmp ", " Archive/* "])
        assert connector.excluded_paths == ["*.tmp", "Archive/*"]
        # Trimmed patterns still match as if entered clean.
        assert connector._is_driveitem_excluded(self._item("file.tmp"))

View File

@@ -211,6 +211,7 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
include_permissions: bool, # noqa: ARG001
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
access_token: str | None = None, # noqa: ARG001
treat_sharing_link_as_public: bool = False, # noqa: ARG001
) -> Document:
captured_drive_names.append(drive_name)
return Document(

View File

@@ -0,0 +1,153 @@
"""Tests for generic Celery task lifecycle Prometheus metrics."""
from collections.abc import Iterator
from unittest.mock import MagicMock
import pytest
from onyx.server.metrics.celery_task_metrics import _task_start_times
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
@pytest.fixture(autouse=True)
def reset_metrics() -> Iterator[None]:
    """Isolate tests from each other's in-flight task timing state."""
    _task_start_times.clear()  # setup: drop leftovers from prior tests
    yield
    _task_start_times.clear()  # teardown: leave no state behind
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
    """Fake Celery task whose request routes to the given queue."""
    fake = MagicMock()
    # .name must be assigned post-construction: MagicMock treats the
    # name= constructor kwarg specially.
    fake.name = name
    fake.request = MagicMock()
    fake.request.delivery_info = {"routing_key": queue}
    return fake
class TestCeleryTaskPrerun:
    """Behavior of the task_prerun signal handler."""

    @staticmethod
    def _counter(metric) -> float:
        # Read the current value for the labels _make_task() produces.
        return metric.labels(task_name="test_task", queue="test_queue")._value.get()

    def test_increments_started_and_active(self) -> None:
        task = _make_task()
        started_before = self._counter(TASK_STARTED)
        active_before = self._counter(TASKS_ACTIVE)
        on_celery_task_prerun("task-1", task)
        assert self._counter(TASK_STARTED) == started_before + 1
        assert self._counter(TASKS_ACTIVE) == active_before + 1

    def test_records_start_time(self) -> None:
        on_celery_task_prerun("task-1", _make_task())
        assert "task-1" in _task_start_times

    def test_noop_when_task_is_none(self) -> None:
        on_celery_task_prerun("task-1", None)
        assert "task-1" not in _task_start_times

    def test_noop_when_task_id_is_none(self) -> None:
        # Should not crash when the id is missing.
        on_celery_task_prerun(None, _make_task())

    def test_handles_missing_delivery_info(self) -> None:
        task = _make_task()
        task.request.delivery_info = None
        on_celery_task_prerun("task-1", task)
        assert "task-1" in _task_start_times
class TestCeleryTaskPostrun:
    """Behavior of the task_postrun signal handler."""

    @staticmethod
    def _completed(outcome: str) -> float:
        return TASK_COMPLETED.labels(
            task_name="test_task", queue="test_queue", outcome=outcome
        )._value.get()

    @staticmethod
    def _active() -> float:
        return TASKS_ACTIVE.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()

    def test_increments_completed_success(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before = self._completed("success")
        on_celery_task_postrun("task-1", task, "SUCCESS")
        assert self._completed("success") == before + 1

    def test_increments_completed_failure(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before = self._completed("failure")
        on_celery_task_postrun("task-1", task, "FAILURE")
        assert self._completed("failure") == before + 1

    def test_decrements_active(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before = self._active()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        assert self._active() == before - 1

    def test_observes_duration(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        sum_before = TASK_DURATION.labels(
            task_name="test_task", queue="test_queue"
        )._sum.get()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        sum_after = TASK_DURATION.labels(
            task_name="test_task", queue="test_queue"
        )._sum.get()
        # The histogram sum grows by the (tiny but positive) elapsed time.
        assert sum_after > sum_before

    def test_cleans_up_start_time(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        assert "task-1" in _task_start_times
        on_celery_task_postrun("task-1", task, "SUCCESS")
        assert "task-1" not in _task_start_times

    def test_noop_when_task_is_none(self) -> None:
        on_celery_task_postrun("task-1", None, "SUCCESS")

    def test_handles_missing_start_time(self) -> None:
        """Postrun without a matching prerun must not raise."""
        on_celery_task_postrun("task-1", _make_task(), "SUCCESS")

View File

@@ -0,0 +1,359 @@
"""Tests for indexing pipeline Prometheus collectors."""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
class TestQueueDepthCollector:
    """Prometheus collector for Celery queue depth / unacked / oldest-age gauges."""

    def test_returns_empty_when_factory_not_set(self) -> None:
        # Without a Redis factory the collector cannot scrape anything.
        collector = QueueDepthCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        # describe() is empty so registry registration never touches Redis.
        collector = QueueDepthCollector()
        assert collector.describe() == []

    def test_collects_queue_depths(self) -> None:
        """A fresh scrape emits depth, unacked, and oldest-age metric families."""
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=5,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value={"task-1", "task-2"},
            ),
        ):
            families = collector.collect()
        # Family order is fixed: depth, unacked, oldest-age.
        assert len(families) == 3
        depth_family = families[0]
        unacked_family = families[1]
        age_family = families[2]
        assert depth_family.name == "onyx_queue_depth"
        assert len(depth_family.samples) > 0
        # Every queue reports the patched length of 5.
        for sample in depth_family.samples:
            assert sample.value == 5
        assert unacked_family.name == "onyx_queue_unacked"
        unacked_labels = {s.labels["queue"] for s in unacked_family.samples}
        assert "docfetching" in unacked_labels
        assert "docprocessing" in unacked_labels
        assert age_family.name == "onyx_queue_oldest_task_age_seconds"
        # Two unacked task ids were patched in, so every sample reads 2.
        for sample in unacked_family.samples:
            assert sample.value == 2

    def test_handles_redis_error_gracefully(self) -> None:
        """Scrape errors must not propagate to the /metrics endpoint."""
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("connection lost"),
        ):
            families = collector.collect()
        # Returns stale cache (empty on first call)
        assert families == []

    def test_caching_returns_stale_within_ttl(self) -> None:
        """Within cache_ttl the collector must not touch Redis again."""
        collector = QueueDepthCollector(cache_ttl=60)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=5,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            first = collector.collect()
        # Second call within TTL should return cached result without calling Redis
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("should not be called"),
        ):
            second = collector.collect()
        assert first is second  # Same object, from cache

    def test_factory_called_each_scrape(self) -> None:
        """Verify the Redis factory is called on each fresh collect, not cached."""
        collector = QueueDepthCollector(cache_ttl=0)
        factory = MagicMock(return_value=MagicMock())
        collector.set_redis_factory(factory)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=0,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            collector.collect()
            collector.collect()
        assert factory.call_count == 2

    def test_error_returns_stale_cache(self) -> None:
        """After one good scrape, a failing scrape serves the last good result."""
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        # First call succeeds
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=10,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            good_result = collector.collect()
        assert len(good_result) == 3
        assert good_result[0].samples[0].value == 10
        # Second call fails — should return stale cache, not empty
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("Redis down"),
        ):
            stale_result = collector.collect()
        assert stale_result is good_result
class TestIndexAttemptCollector:
    """Prometheus collector for active index attempts, scraped per tenant."""

    def test_returns_empty_when_not_configured(self) -> None:
        # configure() was never called, so the collector stays inert.
        collector = IndexAttemptCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        collector = IndexAttemptCollector()
        assert collector.describe() == []

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_collects_index_attempts(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        """One aggregated DB row becomes one labeled gauge sample."""
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        mock_session = MagicMock()
        # get_session_with_current_tenant is used as a context manager.
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        from onyx.db.enums import IndexingStatus

        # Row shape: (status, source enum, cc_pair_id, connector name, count).
        mock_row = (
            IndexingStatus.IN_PROGRESS,
            MagicMock(value="web"),
            81,
            "Table Tennis Blade Guide",
            2,
        )
        mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
            mock_row
        ]
        families = collector.collect()
        assert len(families) == 1
        assert families[0].name == "onyx_index_attempts_active"
        assert len(families[0].samples) == 1
        sample = families[0].samples[0]
        # All labels are stringified; cc_pair_id 81 becomes "81".
        assert sample.labels == {
            "status": "in_progress",
            "source": "web",
            "tenant_id": "public",
            "connector_name": "Table Tennis Blade Guide",
            "cc_pair_id": "81",
        }
        assert sample.value == 2

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_handles_db_error_gracefully(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        # A DB failure must not break the /metrics scrape.
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.side_effect = Exception("DB down")
        families = collector.collect()
        # No stale cache, so returns empty
        assert families == []

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_skips_none_tenant_ids(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = [None]
        families = collector.collect()
        assert len(families) == 1  # Returns the gauge family, just with no samples
        assert len(families[0].samples) == 0
class TestConnectorHealthCollector:
    """Tests for the ConnectorHealthCollector Prometheus custom collector."""

    def test_returns_empty_when_not_configured(self) -> None:
        # Without configure(), collect() must emit nothing.
        collector = ConnectorHealthCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        collector = ConnectorHealthCollector()
        assert collector.describe() == []

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_collects_connector_health(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        """Happy path: one connector row fans out into all six metric families."""
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        mock_session = MagicMock()
        # Emulate the session factory's context-manager protocol.
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        now = datetime.now(tz=timezone.utc)
        last_success = now - timedelta(hours=2)
        mock_status = MagicMock(value="ACTIVE")
        mock_source = MagicMock(value="google_drive")
        # Row: (id, status, in_error, last_success, name, source)
        mock_row = (
            42,
            mock_status,
            True,  # in_repeated_error_state
            last_success,
            "My GDrive Connector",
            mock_source,
        )
        mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
        # Mock the index attempt queries (error counts + docs counts)
        mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
            []
        )
        families = collector.collect()
        assert len(families) == 6
        names = {f.name for f in families}
        assert names == {
            "onyx_connector_last_success_age_seconds",
            "onyx_connector_in_error_state",
            "onyx_connectors_by_status",
            "onyx_connectors_in_error_total",
            "onyx_connector_docs_indexed",
            "onyx_connector_error_count",
        }
        # last_success was 2h ago, so staleness ~7200s; small tolerance covers
        # wall-clock drift between fixture setup and collect().
        staleness = next(
            f for f in families if f.name == "onyx_connector_last_success_age_seconds"
        )
        assert len(staleness.samples) == 1
        assert staleness.samples[0].value == pytest.approx(7200, abs=5)
        error_state = next(
            f for f in families if f.name == "onyx_connector_in_error_state"
        )
        assert error_state.samples[0].value == 1.0
        by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
        assert by_status.samples[0].labels == {
            "tenant_id": "public",
            "status": "ACTIVE",
        }
        assert by_status.samples[0].value == 1
        error_total = next(
            f for f in families if f.name == "onyx_connectors_in_error_total"
        )
        assert error_total.samples[0].value == 1

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_skips_staleness_when_no_last_success(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        """A connector that never succeeded gets no staleness sample."""
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        mock_session = MagicMock()
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        mock_status = MagicMock(value="INITIAL_INDEXING")
        mock_source = MagicMock(value="slack")
        # NOTE(review): per the (id, status, in_error, last_success, name,
        # source) row shape documented in the test above, the 0 here sits in
        # the "name" slot — unused by this assertion, but confirm intent.
        mock_row = (
            10,
            mock_status,
            False,
            None,  # no last_successful_index_time
            0,
            mock_source,
        )
        mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
        families = collector.collect()
        staleness = next(
            f for f in families if f.name == "onyx_connector_last_success_age_seconds"
        )
        assert len(staleness.samples) == 0

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_handles_db_error_gracefully(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        # No prior successful scrape -> no stale cache to fall back on, so a
        # DB failure yields an empty result rather than an exception.
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.side_effect = Exception("DB down")
        families = collector.collect()
        assert families == []

View File

@@ -0,0 +1,96 @@
"""Tests for indexing pipeline setup (Redis factory caching)."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
    """Tests for _make_broker_redis_factory's client caching/reconnect logic."""

    def test_caches_redis_client_across_calls(self) -> None:
        """Factory should reuse the same client on subsequent calls."""
        mock_client = MagicMock()
        # A successful ping() signals the cached connection is still healthy.
        mock_client.ping.return_value = True
        mock_app = _make_mock_app(mock_client)
        factory = _make_broker_redis_factory(mock_app)
        client1 = factory()
        client2 = factory()
        assert client1 is client2
        # broker_connection should only be called once
        assert mock_app.broker_connection.call_count == 1

    def test_reconnects_when_ping_fails(self) -> None:
        """Factory should create a new client if ping fails (stale connection)."""
        mock_client_stale = MagicMock()
        mock_client_stale.ping.side_effect = ConnectionError("disconnected")
        mock_client_fresh = MagicMock()
        mock_client_fresh.ping.return_value = True
        mock_app = _make_mock_app(mock_client_stale)
        factory = _make_broker_redis_factory(mock_app)
        # First call — creates and caches
        client1 = factory()
        assert client1 is mock_client_stale
        assert mock_app.broker_connection.call_count == 1
        # Switch to fresh client for next connection
        mock_conn_fresh = MagicMock()
        mock_conn_fresh.channel.return_value.client = mock_client_fresh
        mock_app.broker_connection.return_value = mock_conn_fresh
        # Second call — ping fails on stale, reconnects
        client2 = factory()
        assert client2 is mock_client_fresh
        assert mock_app.broker_connection.call_count == 2

    def test_reconnect_closes_stale_client(self) -> None:
        """When ping fails, the old client should be closed before reconnecting."""
        mock_client_stale = MagicMock()
        mock_client_stale.ping.side_effect = ConnectionError("disconnected")
        mock_client_fresh = MagicMock()
        mock_client_fresh.ping.return_value = True
        mock_app = _make_mock_app(mock_client_stale)
        factory = _make_broker_redis_factory(mock_app)
        # First call — creates and caches
        factory()
        # Switch to fresh client
        mock_conn_fresh = MagicMock()
        mock_conn_fresh.channel.return_value.client = mock_client_fresh
        mock_app.broker_connection.return_value = mock_conn_fresh
        # Second call — ping fails, should close stale client
        factory()
        mock_client_stale.close.assert_called_once()

    def test_first_call_creates_connection(self) -> None:
        """First call should always create a new connection."""
        mock_client = MagicMock()
        mock_app = _make_mock_app(mock_client)
        factory = _make_broker_redis_factory(mock_app)
        client = factory()
        assert client is mock_client
        mock_app.broker_connection.assert_called_once()

View File

@@ -0,0 +1,335 @@
"""Tests for per-connector indexing task Prometheus metrics."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.indexing_task_metrics import _connector_cache
from onyx.server.metrics.indexing_task_metrics import _indexing_start_times
from onyx.server.metrics.indexing_task_metrics import ConnectorInfo
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_COMPLETED
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_DURATION
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_STARTED
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
@pytest.fixture(autouse=True)
def reset_state() -> Iterator[None]:
    """Clear caches and state between tests.

    Sets CURRENT_TENANT_ID_CONTEXTVAR to a realistic value so cache keys
    are never keyed on an empty string.
    """
    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

    token = CURRENT_TENANT_ID_CONTEXTVAR.set("test_tenant")
    # Clear before AND after so a failing test cannot leak state forward.
    _connector_cache.clear()
    _indexing_start_times.clear()
    yield
    _connector_cache.clear()
    _indexing_start_times.clear()
    # Restore the contextvar to whatever it was before this test.
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _make_task(name: str) -> MagicMock:
task = MagicMock()
task.name = name
return task
def _mock_db_lookup(
source: str = "google_drive", name: str = "My Google Drive"
) -> tuple:
"""Return (session_patch, cc_pair_patch) context managers for DB mocking."""
mock_cc_pair = MagicMock()
mock_cc_pair.name = name
mock_cc_pair.connector.source.value = source
session_patch = patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
cc_pair_patch = patch(
"onyx.db.connector_credential_pair.get_connector_credential_pair_from_id",
return_value=mock_cc_pair,
)
return session_patch, cc_pair_patch
class TestIndexingTaskPrerun:
    """Tests for the Celery prerun signal handler for indexing tasks."""

    def test_skips_non_indexing_task(self) -> None:
        # Non-indexing task names must be ignored entirely.
        task = _make_task("some_other_task")
        kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" not in _indexing_start_times

    def test_emits_started_for_docfetching(self) -> None:
        # Pre-populate cache to avoid DB lookup (tenant-scoped key)
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="google_drive", name="My Google Drive"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "tenant-1"}
        # Delta-based assertion so the test is independent of other tests
        # touching the same label set on the shared global counter.
        before = INDEXING_TASK_STARTED.labels(
            task_name="connector_doc_fetching_task",
            source="google_drive",
            tenant_id="tenant-1",
            cc_pair_id="42",
        )._value.get()
        on_indexing_task_prerun("task-1", task, kwargs)
        after = INDEXING_TASK_STARTED.labels(
            task_name="connector_doc_fetching_task",
            source="google_drive",
            tenant_id="tenant-1",
            cc_pair_id="42",
        )._value.get()
        assert after == before + 1
        assert "task-1" in _indexing_start_times

    def test_emits_started_for_docprocessing(self) -> None:
        _connector_cache[("test_tenant", 10)] = ConnectorInfo(
            source="slack", name="Slack Connector"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 10, "tenant_id": "public"}
        on_indexing_task_prerun("task-2", task, kwargs)
        assert "task-2" in _indexing_start_times

    def test_cache_hit_avoids_db_call(self) -> None:
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="confluence", name="Engineering Confluence"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # No DB patches needed — cache should be used
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" in _indexing_start_times

    def test_db_lookup_on_cache_miss(self) -> None:
        """On first encounter of a cc_pair_id, does a DB lookup and caches."""
        # Fix: removed the unused mock_cc_pair/mock_session scaffolding that
        # the original built — _resolve_connector is patched directly below,
        # so that setup was dead code and obscured what is actually mocked.
        with (
            patch(
                "onyx.server.metrics.indexing_task_metrics._resolve_connector"
            ) as mock_resolve,
        ):
            mock_resolve.return_value = ConnectorInfo(
                source="notion", name="Notion Workspace"
            )
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 77, "tenant_id": "public"}
            on_indexing_task_prerun("task-1", task, kwargs)
            mock_resolve.assert_called_once_with(77)

    def test_missing_cc_pair_returns_unknown(self) -> None:
        """When _resolve_connector can't find the cc_pair, uses 'unknown'."""
        with patch(
            "onyx.server.metrics.indexing_task_metrics._resolve_connector"
        ) as mock_resolve:
            mock_resolve.return_value = ConnectorInfo(source="unknown", name="unknown")
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 999, "tenant_id": "public"}
            on_indexing_task_prerun("task-1", task, kwargs)
            assert "task-1" in _indexing_start_times

    def test_skips_when_cc_pair_id_missing(self) -> None:
        # Malformed kwargs (no cc_pair_id) must be a no-op, not an error.
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" not in _indexing_start_times

    def test_db_error_does_not_crash(self) -> None:
        # Metrics must never take down the worker: DB failures inside the
        # handler are swallowed.
        with patch(
            "onyx.server.metrics.indexing_task_metrics._resolve_connector",
            side_effect=Exception("DB down"),
        ):
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
            # Should not raise
            on_indexing_task_prerun("task-1", task, kwargs)
class TestIndexingTaskPostrun:
    """Tests for the Celery postrun signal handler for indexing tasks."""

    def test_skips_non_indexing_task(self) -> None:
        task = _make_task("some_other_task")
        kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
        # Should not raise

    def test_emits_completed_and_duration(self) -> None:
        """SUCCESS state bumps the completed counter and records a duration."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="google_drive", name="Marketing Drive"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # Simulate prerun
        on_indexing_task_prerun("task-1", task, kwargs)
        before_completed = INDEXING_TASK_COMPLETED.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
            cc_pair_id="42",
            outcome="success",
        )._value.get()
        before_duration = INDEXING_TASK_DURATION.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
        )._sum.get()
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
        after_completed = INDEXING_TASK_COMPLETED.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
            cc_pair_id="42",
            outcome="success",
        )._value.get()
        after_duration = INDEXING_TASK_DURATION.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
        )._sum.get()
        assert after_completed == before_completed + 1
        # Histogram sum strictly increases because prerun recorded a start
        # time for "task-1".
        assert after_duration > before_duration

    def test_failure_outcome(self) -> None:
        """FAILURE state is recorded under outcome='failure'."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="slack", name="Slack"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        before = INDEXING_TASK_COMPLETED.labels(
            task_name="connector_doc_fetching_task",
            source="slack",
            tenant_id="public",
            cc_pair_id="42",
            outcome="failure",
        )._value.get()
        on_indexing_task_postrun("task-1", task, kwargs, "FAILURE")
        after = INDEXING_TASK_COMPLETED.labels(
            task_name="connector_doc_fetching_task",
            source="slack",
            tenant_id="public",
            cc_pair_id="42",
            outcome="failure",
        )._value.get()
        assert after == before + 1

    def test_handles_postrun_without_prerun(self) -> None:
        """Postrun for an indexing task without a matching prerun should not crash."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="slack", name="Slack"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # No prerun — should still emit completed counter, just skip duration
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
class TestResolveConnector:
    """Tests for _resolve_connector's tenant-scoped cache semantics."""

    def test_failed_lookup_not_cached(self) -> None:
        """When DB lookup returns None, result should NOT be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        # Cache keys are (tenant_id, cc_pair_id), so pin a known tenant and
        # restore it in finally regardless of assertion outcome.
        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            with (
                patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
                patch(
                    "onyx.db.connector_credential_pair"
                    ".get_connector_credential_pair_from_id",
                    return_value=None,
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector

                result = _resolve_connector(999)
                assert result.source == "unknown"
                # Should NOT be cached so subsequent calls can retry
                assert ("test-tenant", 999) not in _connector_cache
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

    def test_exception_not_cached(self) -> None:
        """When DB lookup raises, result should NOT be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            with (
                patch(
                    "onyx.db.engine.sql_engine.get_session_with_current_tenant",
                    side_effect=Exception("DB down"),
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector

                result = _resolve_connector(888)
                assert result.source == "unknown"
                assert ("test-tenant", 888) not in _connector_cache
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

    def test_successful_lookup_is_cached(self) -> None:
        """When DB lookup succeeds, result should be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            mock_cc_pair = MagicMock()
            mock_cc_pair.name = "My Drive"
            mock_cc_pair.connector.source.value = "google_drive"
            with (
                patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
                patch(
                    "onyx.db.connector_credential_pair"
                    ".get_connector_credential_pair_from_id",
                    return_value=mock_cc_pair,
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector

                result = _resolve_connector(777)
                assert result.source == "google_drive"
                assert result.name == "My Drive"
                assert ("test-tenant", 777) in _connector_cache
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -0,0 +1,69 @@
"""Tests for the Prometheus metrics server module."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.metrics_server import _DEFAULT_PORTS
from onyx.server.metrics.metrics_server import start_metrics_server
@pytest.fixture(autouse=True)
def reset_server_state() -> Iterator[None]:
    """Reset the global _server_started between tests."""
    import onyx.server.metrics.metrics_server as mod

    # Clear the module-level "already started" flag before and after each
    # test so start_metrics_server's idempotence guard starts fresh.
    mod._server_started = False
    yield
    mod._server_started = False
class TestStartMetricsServer:
    """Tests for start_metrics_server's port selection and failure handling."""

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_uses_default_port_for_known_worker(self, mock_start: MagicMock) -> None:
        # Known worker type -> its entry in _DEFAULT_PORTS is used.
        port = start_metrics_server("monitoring")
        assert port == _DEFAULT_PORTS["monitoring"]
        mock_start.assert_called_once_with(_DEFAULT_PORTS["monitoring"])

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "9999"})
    def test_env_var_overrides_default(self, mock_start: MagicMock) -> None:
        port = start_metrics_server("monitoring")
        assert port == 9999
        mock_start.assert_called_once_with(9999)

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_ENABLED": "false"})
    def test_disabled_via_env_var(self, mock_start: MagicMock) -> None:
        port = start_metrics_server("monitoring")
        assert port is None
        mock_start.assert_not_called()

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_unknown_worker_type_no_env_var(self, mock_start: MagicMock) -> None:
        # Unknown worker with no env override -> no server is started.
        port = start_metrics_server("unknown_worker")
        assert port is None
        mock_start.assert_not_called()

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_idempotent(self, mock_start: MagicMock) -> None:
        # Second invocation is a no-op: returns None, server started once.
        port1 = start_metrics_server("monitoring")
        port2 = start_metrics_server("monitoring")
        assert port1 == _DEFAULT_PORTS["monitoring"]
        assert port2 is None
        mock_start.assert_called_once()

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_handles_os_error(self, mock_start: MagicMock) -> None:
        # e.g. port already bound — the error must not propagate.
        mock_start.side_effect = OSError("Address already in use")
        port = start_metrics_server("monitoring")
        assert port is None

    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "not_a_number"})
    def test_invalid_port_env_var_returns_none(self, mock_start: MagicMock) -> None:
        # A non-numeric override is rejected rather than crashing startup.
        port = start_metrics_server("monitoring")
        assert port is None
        mock_start.assert_not_called()

View File

@@ -0,0 +1,123 @@
"""Tests for OpenSearch search Prometheus metrics."""
from unittest.mock import patch
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.server.metrics.opensearch_search import _client_duration
from onyx.server.metrics.opensearch_search import _search_total
from onyx.server.metrics.opensearch_search import _searches_in_progress
from onyx.server.metrics.opensearch_search import _server_duration
from onyx.server.metrics.opensearch_search import observe_opensearch_search
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
class TestObserveOpenSearchSearch:
    """Tests for observe_opensearch_search (counter + two duration histograms)."""

    def test_increments_counter(self) -> None:
        search_type = OpenSearchSearchType.HYBRID
        # Delta-based assertions: the metrics are process-global, so compare
        # before/after rather than asserting absolute values.
        before = _search_total.labels(search_type=search_type.value)._value.get()
        observe_opensearch_search(search_type, 0.1, 50)
        after = _search_total.labels(search_type=search_type.value)._value.get()
        assert after == before + 1

    def test_observes_client_duration(self) -> None:
        search_type = OpenSearchSearchType.KEYWORD
        before_sum = _client_duration.labels(search_type=search_type.value)._sum.get()
        observe_opensearch_search(search_type, 0.25, 100)
        after_sum = _client_duration.labels(search_type=search_type.value)._sum.get()
        assert after_sum == before_sum + 0.25

    def test_observes_server_duration(self) -> None:
        search_type = OpenSearchSearchType.SEMANTIC
        before_sum = _server_duration.labels(search_type=search_type.value)._sum.get()
        observe_opensearch_search(search_type, 0.3, 200)
        after_sum = _server_duration.labels(search_type=search_type.value)._sum.get()
        # 200ms should be recorded as 0.2s.
        assert after_sum == before_sum + 0.2

    def test_server_took_none_skips_server_histogram(self) -> None:
        search_type = OpenSearchSearchType.UNKNOWN
        before_server = _server_duration.labels(
            search_type=search_type.value
        )._sum.get()
        before_client = _client_duration.labels(
            search_type=search_type.value
        )._sum.get()
        before_total = _search_total.labels(search_type=search_type.value)._value.get()
        observe_opensearch_search(search_type, 0.1, None)
        # Server histogram should NOT be observed.
        after_server = _server_duration.labels(search_type=search_type.value)._sum.get()
        assert after_server == before_server
        # Client histogram and counter should still work.
        after_client = _client_duration.labels(search_type=search_type.value)._sum.get()
        after_total = _search_total.labels(search_type=search_type.value)._value.get()
        assert after_client == before_client + 0.1
        assert after_total == before_total + 1

    def test_exceptions_do_not_propagate(self) -> None:
        # Metric emission must never break the calling search path.
        search_type = OpenSearchSearchType.RANDOM
        with patch.object(
            _search_total.labels(search_type=search_type.value),
            "inc",
            side_effect=RuntimeError("boom"),
        ):
            # Should not raise.
            observe_opensearch_search(search_type, 0.1, 50)
class TestTrackOpenSearchSearchInProgress:
    """Tests for the in-progress gauge context manager."""

    def test_gauge_increments_and_decrements(self) -> None:
        search_type = OpenSearchSearchType.HYBRID
        before = _searches_in_progress.labels(
            search_type=search_type.value
        )._value.get()
        with track_opensearch_search_in_progress(search_type):
            during = _searches_in_progress.labels(
                search_type=search_type.value
            )._value.get()
            assert during == before + 1
        after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
        assert after == before

    def test_gauge_decrements_on_exception(self) -> None:
        # The gauge must be restored even when the wrapped search raises.
        search_type = OpenSearchSearchType.SEMANTIC
        before = _searches_in_progress.labels(
            search_type=search_type.value
        )._value.get()
        raised = False
        try:
            with track_opensearch_search_in_progress(search_type):
                raise ValueError("simulated search failure")
        except ValueError:
            raised = True
        assert raised
        after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
        assert after == before

    def test_inc_exception_does_not_break_search(self) -> None:
        # If inc() itself fails, the manager must still yield and must not
        # decrement a gauge it never incremented.
        search_type = OpenSearchSearchType.KEYWORD
        before = _searches_in_progress.labels(
            search_type=search_type.value
        )._value.get()
        with patch.object(
            _searches_in_progress.labels(search_type=search_type.value),
            "inc",
            side_effect=RuntimeError("boom"),
        ):
            # Context manager should still yield without decrementing.
            with track_opensearch_search_in_progress(search_type):
                # Search logic would execute here.
                during = _searches_in_progress.labels(
                    search_type=search_type.value
                )._value.get()
                assert during == before
        after = _searches_in_progress.labels(search_type=search_type.value)._value.get()
        assert after == before

View File

@@ -169,6 +169,21 @@ Engine label values: `sync` (main read-write), `async` (async sessions), `readon
Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`.
## OpenSearch Search Metrics
These metrics track OpenSearch search latency and throughput. Collected via `onyx.server.metrics.opensearch_search`.
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_opensearch_search_client_duration_seconds` | Histogram | `search_type` | Client-side end-to-end latency (network + serialization + server execution) |
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
Search type label values: See `OpenSearchSearchType`.
---
## Example PromQL Queries
### Which endpoints are saturated right now?
@@ -258,3 +273,33 @@ histogram_quantile(0.99, sum by (handler, le) (rate(onyx_db_connection_hold_seco
# Checkouts per second by engine
sum by (engine) (rate(onyx_db_pool_checkout_total[5m]))
```
### OpenSearch P99 search latency by type
```promql
# P99 client-side latency by search type
histogram_quantile(0.99, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))
```
### OpenSearch search throughput
```promql
# Searches per second by type
sum by (search_type) (rate(onyx_opensearch_search_total[5m]))
```
### OpenSearch concurrent searches
```promql
# Total in-flight searches across all instances
sum(onyx_opensearch_searches_in_progress)
```
### OpenSearch network overhead
```promql
# Difference between client and server P50 reveals network/serialization cost.
histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))
-
histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))
```

View File

@@ -6271,9 +6271,9 @@
"license": "ISC"
},
"node_modules/picomatch": {
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"dev": true,
"license": "MIT",
"engines": {
@@ -7179,9 +7179,9 @@
}
},
"node_modules/tinyglobby/node_modules/picomatch": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"dev": true,
"license": "MIT",
"engines": {

View File

@@ -92,7 +92,7 @@ backend = [
"python-gitlab==5.6.0",
"python-pptx==0.6.23",
"pypandoc_binary==1.16.2",
"pypdf==6.9.1",
"pypdf==6.9.2",
"pytest-mock==3.12.0",
"pytest-playwright==0.7.0",
"python-docx==1.1.2",
@@ -100,7 +100,7 @@ backend = [
"python-multipart==0.0.22",
"pywikibot==9.0.0",
"redis==5.0.8",
"requests==2.32.5",
"requests==2.33.0",
"requests-oauthlib==1.3.1",
"rfc3986==1.5.0",
"simple-salesforce==1.12.6",

22
uv.lock generated
View File

@@ -3909,7 +3909,7 @@ wheels = [
[[package]]
name = "nltk"
version = "3.9.3"
version = "3.9.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
@@ -3917,9 +3917,9 @@ dependencies = [
{ name = "regex" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" }
sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" },
{ url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" },
]
[[package]]
@@ -4481,7 +4481,7 @@ requires-dist = [
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.1" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.2" },
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
@@ -4502,7 +4502,7 @@ requires-dist = [
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.33.0" },
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
{ name = "retry", specifier = "==0.9.2" },
{ name = "rfc3986", marker = "extra == 'backend'", specifier = "==1.5.0" },
@@ -5727,11 +5727,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.9.1"
version = "6.9.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
{ url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" },
]
[[package]]
@@ -6378,7 +6378,7 @@ wheels = [
[[package]]
name = "requests"
version = "2.32.5"
version = "2.33.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
@@ -6386,9 +6386,9 @@ dependencies = [
{ name = "idna" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = "sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
{ url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" },
]
[[package]]

View File

@@ -281,35 +281,90 @@ If you need help with this step, reach out to `raunak@onyx.app`.
## 3. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Use the Opal `Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It uses
string-enum props (`font` and `color`) for font preset and color selection. Inline markdown is
opt-in via the `markdown()` function from `@opal/utils` (the `RichStr` type lives in `@opal/types`).
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
// ✅ Good — Opal Text with string-enum props
import { Text } from "@opal/components";
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
<Text font="main-ui-action" color="text-03">
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
// ✅ Good — inline markdown via markdown()
import { markdown } from "@opal/utils";
<Text font="main-ui-body" color="text-05">
{markdown("*Hello*, **world**! Visit [Onyx](https://onyx.app) and run `onyx start`.")}
</Text>
// ✅ Good — plain strings are never parsed as markdown
<Text font="main-ui-body" color="text-03">
{userProvidedString}
</Text>
// ✅ Good — component props that support optional markdown use `string | RichStr`
import type { RichStr } from "@opal/types";
interface MyCardProps {
title: string | RichStr;
}
// ❌ Bad — legacy boolean-flag API (still works but deprecated)
import Text from "@/refresh-components/texts/Text";
<Text text03 mainUiAction>{name}</Text>
// ❌ Bad — naked text nodes
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
```
Key props:
- `font`: `TextFont` — font preset (e.g., `"main-ui-body"`, `"heading-h2"`, `"secondary-action"`)
- `color`: `TextColor` — text color (e.g., `"text-03"`, `"text-inverted-05"`)
- `as`: `"p" | "span" | "li" | "h1" | "h2" | "h3"` — HTML tag (default: `"span"`)
- `nowrap`: `boolean` — prevent text wrapping
**`RichStr` convention:** When creating new components, any string prop that will be rendered as
visible text in the DOM (e.g., `title`, `description`, `label`) should be typed as
`string | RichStr` instead of plain `string`. This gives callers opt-in markdown support via
`markdown()` without requiring any additional props or API surface on the component.
```typescript
import type { RichStr } from "@opal/types";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
// ✅ Good — new components accept string | RichStr
interface InfoCardProps {
title: string | RichStr;
description?: string | RichStr;
}
function InfoCard({ title, description }: InfoCardProps) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
<Text font="main-ui-action">{resolveStr(title)}</Text>
{description && (
<Text font="secondary-body" color="text-03">{resolveStr(description)}</Text>
)}
</div>
)
);
}
// ❌ Bad — plain string props block markdown support for callers
interface InfoCardProps {
title: string;
description?: string;
}
```

View File

@@ -33,6 +33,14 @@ export {
type LineItemButtonProps,
} from "@opal/components/buttons/line-item-button/components";
/* Text */
export {
Text,
type TextProps,
type TextFont,
type TextColor,
} from "@opal/components/text/components";
/* Tag */
export {
Tag,

View File

@@ -145,6 +145,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
pageSize,
initialSorting,
initialColumnVisibility,
initialRowSelection,
initialViewSelected,
draggable,
footer,
size = "lg",
@@ -221,6 +223,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
pageSize: effectivePageSize,
initialSorting,
initialColumnVisibility,
initialRowSelection,
initialViewSelected,
getRowId,
onSelectionChange,
searchTerm,

View File

@@ -103,6 +103,10 @@ interface UseDataTableOptions<TData extends RowData> {
initialSorting?: SortingState;
/** Initial column visibility state. @default {} */
initialColumnVisibility?: VisibilityState;
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. @default {} */
initialRowSelection?: RowSelectionState;
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode (filtered to selected rows). @default false */
initialViewSelected?: boolean;
/** Called whenever the set of selected row IDs changes. */
onSelectionChange?: (selectedIds: string[]) => void;
/** Search term for global text filtering. Rows are filtered to those containing
@@ -195,6 +199,8 @@ export default function useDataTable<TData extends RowData>(
columnResizeMode = "onChange",
initialSorting = [],
initialColumnVisibility = {},
initialRowSelection = {},
initialViewSelected = false,
getRowId,
onSelectionChange,
searchTerm,
@@ -206,7 +212,8 @@ export default function useDataTable<TData extends RowData>(
// ---- internal state -----------------------------------------------------
const [sorting, setSorting] = useState<SortingState>(initialSorting);
const [rowSelection, setRowSelection] = useState<RowSelectionState>({});
const [rowSelection, setRowSelection] =
useState<RowSelectionState>(initialRowSelection);
const [columnSizing, setColumnSizing] = useState<ColumnSizingState>({});
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
initialColumnVisibility
@@ -216,8 +223,12 @@ export default function useDataTable<TData extends RowData>(
pageSize: pageSizeOption,
});
/** Combined global filter: view-mode (selected IDs) + text search. */
const initialSelectedIds =
initialViewSelected && Object.keys(initialRowSelection).length > 0
? new Set(Object.keys(initialRowSelection))
: null;
const [globalFilter, setGlobalFilter] = useState<GlobalFilterValue>({
selectedIds: null,
selectedIds: initialSelectedIds,
searchTerm: "",
});
@@ -384,6 +395,31 @@ export default function useDataTable<TData extends RowData>(
: data.length;
const isPaginated = isFinite(pagination.pageSize);
// ---- keep view-mode filter in sync with selection ----------------------
// When in view-selected mode, deselecting a row should remove it from
// the visible set so it disappears immediately.
useEffect(() => {
if (isServerSide) return;
if (globalFilter.selectedIds == null) return;
const currentIds = new Set(Object.keys(rowSelection));
// Remove any ID from the filter that is no longer selected
let changed = false;
const next = new Set<string>();
globalFilter.selectedIds.forEach((id) => {
if (currentIds.has(id)) {
next.add(id);
} else {
changed = true;
}
});
if (changed) {
setGlobalFilter((prev) => ({ ...prev, selectedIds: next }));
}
// eslint-disable-next-line react-hooks/exhaustive-deps -- only react to
// selection changes while in view mode
}, [rowSelection, isServerSide]);
// ---- selection change callback ------------------------------------------
const isFirstRenderRef = useRef(true);
const onSelectionChangeRef = useRef(onSelectionChange);
@@ -392,6 +428,10 @@ export default function useDataTable<TData extends RowData>(
useEffect(() => {
if (isFirstRenderRef.current) {
isFirstRenderRef.current = false;
// Still fire the callback on first render if there's an initial selection
if (selectedRowIds.length > 0) {
onSelectionChangeRef.current?.(selectedRowIds);
}
return;
}
onSelectionChangeRef.current?.(selectedRowIds);

View File

@@ -146,6 +146,10 @@ export interface DataTableProps<TData> {
initialSorting?: SortingState;
/** Initial column visibility state. */
initialColumnVisibility?: VisibilityState;
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. */
initialRowSelection?: Record<string, boolean>;
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode. @default false */
initialViewSelected?: boolean;
/** Enable drag-and-drop row reordering. */
draggable?: DataTableDraggableConfig;
/** Footer configuration. */

View File

@@ -0,0 +1,76 @@
import type { ReactNode } from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import type { RichStr } from "@opal/types";
// ---------------------------------------------------------------------------
// InlineMarkdown
// ---------------------------------------------------------------------------
// Only link protocols that cannot execute script; anything else (javascript:,
// data:, etc.) fails the test and the anchor is dropped.
const SAFE_PROTOCOL = /^https?:|^mailto:|^tel:/i;
// Inline-only element whitelist for react-markdown. Disallowed elements are
// stripped but their children are kept (see `unwrapDisallowed` at the call site).
const ALLOWED_ELEMENTS = ["p", "a", "strong", "em", "code", "del"];
// Custom renderers for the allowed elements.
const INLINE_COMPONENTS = {
  // Unwrap the <p> that react-markdown emits so the output stays inline.
  p: ({ children }: { children?: ReactNode }) => <>{children}</>,
  a: ({ children, href }: { children?: ReactNode; href?: string }) => {
    // Unsafe or missing href: drop the anchor but keep its text content.
    if (!href || !SAFE_PROTOCOL.test(href)) {
      return <>{children}</>;
    }
    const isHttp = /^https?:/i.test(href);
    // http(s) links open in a new tab; mailto:/tel: links open natively.
    return (
      <a
        href={href}
        className="underline underline-offset-2"
        {...(isHttp ? { target: "_blank", rel: "noopener noreferrer" } : {})}
      >
        {children}
      </a>
    );
  },
  // Inline code: monospace family plus a subtle background tint; font size is
  // inherited from the parent so code scales with the surrounding text.
  code: ({ children }: { children?: ReactNode }) => (
    <code className="[font-family:var(--font-dm-mono)] bg-background-tint-02 rounded px-1 py-0.5">
      {children}
    </code>
  ),
};
interface InlineMarkdownProps {
  /** Raw markdown source; rendered as inline-only markdown. */
  content: string;
}
/**
 * Renders a string as inline markdown (bold, italic, code, links,
 * strikethrough). Output is restricted to the inline element whitelist;
 * block wrappers are stripped, so multi-paragraph input collapses into
 * a single run of text.
 */
export default function InlineMarkdown(props: InlineMarkdownProps) {
  const { content } = props;
  return (
    <ReactMarkdown
      remarkPlugins={[remarkGfm]}
      allowedElements={ALLOWED_ELEMENTS}
      unwrapDisallowed
      components={INLINE_COMPONENTS}
    >
      {content}
    </ReactMarkdown>
  );
}
// ---------------------------------------------------------------------------
// RichStr helpers
// ---------------------------------------------------------------------------
function isRichStr(value: unknown): value is RichStr {
return (
typeof value === "object" &&
value !== null &&
(value as RichStr).__brand === "RichStr"
);
}
/** Resolves `string | RichStr` to a `ReactNode`. */
export function resolveStr(value: string | RichStr): ReactNode {
return isRichStr(value) ? <InlineMarkdown content={value.raw} /> : value;
}
/** Extracts the plain string from `string | RichStr`. */
export function toPlainString(value: string | RichStr): string {
return isRichStr(value) ? value.raw : value;
}

View File

@@ -0,0 +1,124 @@
# Text
**Import:** `import { Text, type TextProps, type TextFont, type TextColor } from "@opal/components";`
A styled text component with string-enum props for font preset and color selection. Supports
inline markdown rendering via `RichStr` — pass `markdown("*bold* text")` as children to enable.
## Props
| Prop | Type | Default | Description |
|---|---|---|---|
| `font` | `TextFont` | `"main-ui-body"` | Font preset (size, weight, line-height) |
| `color` | `TextColor` | `"text-04"` | Text color |
| `as` | `"p" \| "span" \| "li" \| "h1" \| "h2" \| "h3"` | `"span"` | HTML tag to render |
| `nowrap` | `boolean` | `false` | Prevent text wrapping |
| `children` | `string \| RichStr` | — | Plain string or `markdown()` for inline markdown |
### `TextFont`
| Value | Size | Weight | Line-height |
|---|---|---|---|
| `"heading-h1"` | 48px | 600 | 64px |
| `"heading-h2"` | 24px | 600 | 36px |
| `"heading-h3"` | 18px | 600 | 28px |
| `"heading-h3-muted"` | 18px | 500 | 28px |
| `"main-content-body"` | 16px | 450 | 24px |
| `"main-content-muted"` | 16px | 400 | 24px |
| `"main-content-emphasis"` | 16px | 700 | 24px |
| `"main-content-mono"` | 16px | 400 | 23px |
| `"main-ui-body"` | 14px | 500 | 20px |
| `"main-ui-muted"` | 14px | 400 | 20px |
| `"main-ui-action"` | 14px | 600 | 20px |
| `"main-ui-mono"` | 14px | 400 | 20px |
| `"secondary-body"` | 12px | 400 | 18px |
| `"secondary-action"` | 12px | 600 | 18px |
| `"secondary-mono"` | 12px | 400 | 18px |
| `"figure-small-label"` | 10px | 600 | 14px |
| `"figure-small-value"` | 10px | 400 | 14px |
| `"figure-keystroke"` | 11px | 400 | 16px |
### `TextColor`
`"text-01" | "text-02" | "text-03" | "text-04" | "text-05" | "text-inverted-01" | "text-inverted-02" | "text-inverted-03" | "text-inverted-04" | "text-inverted-05" | "text-light-03" | "text-light-05" | "text-dark-03" | "text-dark-05"`
## Usage Examples
```tsx
import { Text } from "@opal/components";
// Basic
<Text font="main-ui-body" color="text-03">
Hello world
</Text>
// Heading
<Text font="heading-h2" color="text-05" as="h2">
Page Title
</Text>
// Inverted (for dark backgrounds)
<Text font="main-ui-body" color="text-inverted-05">
Light text on dark
</Text>
// As paragraph
<Text font="main-content-body" color="text-03" as="p">
A full paragraph of text.
</Text>
```
## Inline Markdown via `RichStr`
Inline markdown is opt-in via the `markdown()` function, which returns a `RichStr`. When `Text`
receives a `RichStr` as children, it parses the inner string as inline markdown. Plain strings
are rendered as-is — no parsing, no surprises. `Text` does not accept arbitrary JSX as children;
use `string | RichStr` only.
```tsx
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
// Inline markdown — bold, italic, links, code, strikethrough
<Text font="main-ui-body" color="text-05">
{markdown("*Hello*, **world**! Visit [Onyx](https://onyx.app) and run `onyx start`.")}
</Text>
// Plain string — no markdown parsing
<Text font="main-ui-body" color="text-03">
This *stays* as-is, no formatting applied.
</Text>
```
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`.
Markdown rendering uses `react-markdown` internally, restricted to inline elements only.
`http(s)` links open in a new tab; `mailto:` and `tel:` links open natively. Inline code
inherits the parent font size and switches to the monospace family.
**Note:** This is inline-only markdown. Multi-paragraph content (`"Hello\n\nWorld"`) will
collapse into a single run of text since paragraph wrappers are stripped. For block-level
markdown, use `MinimalMarkdown` instead.
### Using `RichStr` in component props
Components that want to support optional markdown in their text props should accept
`string | RichStr`:
```tsx
import type { RichStr } from "@opal/types";
interface MyComponentProps {
title: string | RichStr;
description?: string | RichStr;
}
```
This avoids API coloring — no `markdown` boolean needs to be threaded through intermediate
components. The decision to use markdown lives at the call site.
## Compatibility
`@/refresh-components/texts/Text` is an independent legacy component that implements the same
font/color presets via a boolean-flag API. It is **not** a wrapper around this component. New
code should import directly from `@opal/components`.

View File

@@ -0,0 +1,257 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Text } from "@opal/components";
import type { TextFont, TextColor } from "@opal/components";
import { markdown } from "@opal/utils";
// Storybook metadata: registers Text under "opal/components" with autodocs.
const meta: Meta<typeof Text> = {
  title: "opal/components/Text",
  component: Text,
  tags: ["autodocs"],
};
export default meta;
// Shared story type for all exports below.
type Story = StoryObj<typeof Text>;
// ---------------------------------------------------------------------------
// Basic
// ---------------------------------------------------------------------------
// Component defaults only (no font/color/as props supplied).
export const Default: Story = {
  args: {
    children: "The quick brown fox jumps over the lazy dog",
  },
};
// `as` controls the rendered HTML tag independently of the font preset.
export const AsHeading: Story = {
  args: {
    font: "heading-h2",
    color: "text-05",
    as: "h2",
    children: "Page Title",
  },
};
export const AsParagraph: Story = {
  args: {
    font: "main-content-body",
    color: "text-03",
    as: "p",
    children: "A full paragraph of body text rendered as a p element.",
  },
};
// `nowrap` keeps the text on one line even inside a deliberately narrow box.
export const Nowrap: Story = {
  render: () => (
    <div className="w-48 border border-border-02 rounded p-2">
      <Text font="main-ui-body" color="text-05" nowrap>
        This text will not wrap even though the container is narrow
      </Text>
    </div>
  ),
};
// ---------------------------------------------------------------------------
// Fonts
// ---------------------------------------------------------------------------
// Every TextFont value, listed so the gallery story below shows each preset.
const ALL_FONTS: TextFont[] = [
  "heading-h1",
  "heading-h2",
  "heading-h3",
  "heading-h3-muted",
  "main-content-body",
  "main-content-muted",
  "main-content-emphasis",
  "main-content-mono",
  "main-ui-body",
  "main-ui-muted",
  "main-ui-action",
  "main-ui-mono",
  "secondary-body",
  "secondary-action",
  "secondary-mono",
  "figure-small-label",
  "figure-small-value",
  "figure-keystroke",
];
// One labelled sample line per font preset.
export const AllFonts: Story = {
  render: () => (
    <div className="space-y-2">
      {ALL_FONTS.map((font) => (
        <div key={font} className="flex items-baseline gap-4">
          <span className="w-56 shrink-0 font-secondary-body text-text-03">
            {font}
          </span>
          <Text font={font} color="text-05">
            The quick brown fox
          </Text>
        </div>
      ))}
    </div>
  ),
};
// ---------------------------------------------------------------------------
// Colors
// ---------------------------------------------------------------------------
// Colors intended for light backgrounds.
const STANDARD_COLORS: TextColor[] = [
  "text-01",
  "text-02",
  "text-03",
  "text-04",
  "text-05",
];
// Colors intended for dark/inverted backgrounds.
const INVERTED_COLORS: TextColor[] = [
  "text-inverted-01",
  "text-inverted-02",
  "text-inverted-03",
  "text-inverted-04",
  "text-inverted-05",
];
// One labelled sample line per standard color.
export const AllColors: Story = {
  render: () => (
    <div className="space-y-2">
      {STANDARD_COLORS.map((color) => (
        <div key={color} className="flex items-baseline gap-4">
          <span className="w-56 shrink-0 font-secondary-body text-text-03">
            {color}
          </span>
          <Text font="main-ui-body" color={color}>
            The quick brown fox
          </Text>
        </div>
      ))}
    </div>
  ),
};
// Inverted colors rendered on a dark panel; the label uses an inline style
// because no design-system text color is legible here for the label itself.
export const InvertedColors: Story = {
  render: () => (
    <div className="bg-background-inverted-01 rounded-lg p-6 space-y-2">
      {INVERTED_COLORS.map((color) => (
        <div key={color} className="flex items-baseline gap-4">
          <span
            className="w-56 shrink-0 font-secondary-body"
            style={{ color: "rgba(255,255,255,0.5)" }}
          >
            {color}
          </span>
          <Text font="main-ui-body" color={color}>
            The quick brown fox
          </Text>
        </div>
      ))}
    </div>
  ),
};
// ---------------------------------------------------------------------------
// Markdown via RichStr
// ---------------------------------------------------------------------------
// Each story below opts into inline markdown by passing markdown() as children.
export const MarkdownBold: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown("This is **bold** text"),
  },
};
export const MarkdownItalic: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown("This is *italic* text"),
  },
};
export const MarkdownCode: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown("Run `npm install` to get started"),
  },
};
export const MarkdownLink: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown("Visit [Onyx](https://www.onyx.app/) for more info"),
  },
};
export const MarkdownStrikethrough: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown("This is ~~deleted~~ text"),
  },
};
// All supported inline syntaxes in one string.
export const MarkdownCombined: Story = {
  args: {
    font: "main-ui-body",
    color: "text-05",
    children: markdown(
      "*Hello*, **world**! Check out [Onyx](https://www.onyx.app/) and run `onyx start` to begin."
    ),
  },
};
// Markdown formatting should scale with the surrounding font preset.
export const MarkdownAtDifferentSizes: Story = {
  render: () => (
    <div className="space-y-3">
      <Text font="heading-h2" color="text-05" as="h2">
        {markdown("**Heading** with *emphasis* and `code`")}
      </Text>
      <Text font="main-content-body" color="text-03" as="p">
        {markdown("**Main content** with *emphasis* and `code`")}
      </Text>
      <Text font="secondary-body" color="text-03">
        {markdown("**Secondary** with *emphasis* and `code`")}
      </Text>
    </div>
  ),
};
// Control case: a plain string is never parsed, even when it contains
// markdown-looking characters.
export const PlainStringNotParsed: Story = {
  render: () => (
    <div className="space-y-2">
      <Text font="main-ui-body" color="text-05">
        {
          "This has *asterisks* and **double asterisks** but they are NOT parsed."
        }
      </Text>
    </div>
  ),
};
// ---------------------------------------------------------------------------
// Tag Variants
// ---------------------------------------------------------------------------
// One example per supported `as` tag (span is the default).
export const TagVariants: Story = {
  render: () => (
    <div className="space-y-2">
      <Text font="main-ui-body" color="text-05">
        Default (span): inline text
      </Text>
      <Text font="main-ui-body" color="text-05" as="p">
        Paragraph (p): block text
      </Text>
      <Text font="heading-h2" color="text-05" as="h2">
        Heading (h2): semantic heading
      </Text>
      <ul className="list-disc pl-6">
        <Text font="main-ui-body" color="text-05" as="li">
          List item (li): inside a list
        </Text>
      </ul>
    </div>
  ),
};

View File

@@ -0,0 +1,134 @@
import type { HTMLAttributes } from "react";
import type { RichStr, WithoutStyles } from "@opal/types";
import { cn } from "@opal/utils";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
// Font preset names; each maps 1:1 to an entry in FONT_CONFIG below.
type TextFont =
  | "heading-h1"
  | "heading-h2"
  | "heading-h3"
  | "heading-h3-muted"
  | "main-content-body"
  | "main-content-muted"
  | "main-content-emphasis"
  | "main-content-mono"
  | "main-ui-body"
  | "main-ui-muted"
  | "main-ui-action"
  | "main-ui-mono"
  | "secondary-body"
  | "secondary-action"
  | "secondary-mono"
  | "figure-small-label"
  | "figure-small-value"
  | "figure-keystroke";
// Text color names; each maps 1:1 to an entry in COLOR_CONFIG below.
// "inverted" variants are for dark backgrounds; "light"/"dark" variants are
// fixed-luminance colors (presumably theme-independent — confirm in Figma).
type TextColor =
  | "text-01"
  | "text-02"
  | "text-03"
  | "text-04"
  | "text-05"
  | "text-inverted-01"
  | "text-inverted-02"
  | "text-inverted-03"
  | "text-inverted-04"
  | "text-inverted-05"
  | "text-light-03"
  | "text-light-05"
  | "text-dark-03"
  | "text-dark-05";
// Accepts standard HTML attributes minus `color`/`children` (replaced by the
// typed props below) and minus `className`/`style` (WithoutStyles enforces
// design-system styling only).
interface TextProps
  extends WithoutStyles<
    Omit<HTMLAttributes<HTMLElement>, "color" | "children">
  > {
  /** Font preset. Default: `"main-ui-body"`. */
  font?: TextFont;
  /** Color variant. Default: `"text-04"`. */
  color?: TextColor;
  /** HTML tag to render. Default: `"span"`. */
  as?: "p" | "span" | "li" | "h1" | "h2" | "h3";
  /** Prevent text wrapping. */
  nowrap?: boolean;
  /** Plain string or `markdown()` for inline markdown. */
  children?: string | RichStr;
}
// ---------------------------------------------------------------------------
// Config
// ---------------------------------------------------------------------------
// Font preset -> utility class. Values are spelled out literally rather than
// derived as `font-${font}` — presumably so the CSS tooling can statically
// detect each class name; confirm before refactoring to string interpolation.
const FONT_CONFIG: Record<TextFont, string> = {
  "heading-h1": "font-heading-h1",
  "heading-h2": "font-heading-h2",
  "heading-h3": "font-heading-h3",
  "heading-h3-muted": "font-heading-h3-muted",
  "main-content-body": "font-main-content-body",
  "main-content-muted": "font-main-content-muted",
  "main-content-emphasis": "font-main-content-emphasis",
  "main-content-mono": "font-main-content-mono",
  "main-ui-body": "font-main-ui-body",
  "main-ui-muted": "font-main-ui-muted",
  "main-ui-action": "font-main-ui-action",
  "main-ui-mono": "font-main-ui-mono",
  "secondary-body": "font-secondary-body",
  "secondary-action": "font-secondary-action",
  "secondary-mono": "font-secondary-mono",
  "figure-small-label": "font-figure-small-label",
  "figure-small-value": "font-figure-small-value",
  "figure-keystroke": "font-figure-keystroke",
};
// Color variant -> utility class (same literal-spelling rationale as above).
const COLOR_CONFIG: Record<TextColor, string> = {
  "text-01": "text-text-01",
  "text-02": "text-text-02",
  "text-03": "text-text-03",
  "text-04": "text-text-04",
  "text-05": "text-text-05",
  "text-inverted-01": "text-text-inverted-01",
  "text-inverted-02": "text-text-inverted-02",
  "text-inverted-03": "text-text-inverted-03",
  "text-inverted-04": "text-text-inverted-04",
  "text-inverted-05": "text-text-inverted-05",
  "text-light-03": "text-text-light-03",
  "text-light-05": "text-text-light-05",
  "text-dark-03": "text-text-dark-03",
  "text-dark-05": "text-text-dark-05",
};
// ---------------------------------------------------------------------------
// Text
// ---------------------------------------------------------------------------
/**
 * Design-system text primitive.
 *
 * Renders `children` inside the chosen HTML tag with the preset font and
 * color classes. Plain string children render as-is; a `RichStr` (created
 * via `markdown()`) is parsed as inline markdown by `resolveStr`.
 */
function Text(props: TextProps) {
  const {
    font = "main-ui-body",
    color = "text-04",
    as = "span",
    nowrap,
    children,
    ...rest
  } = props;
  // JSX needs a capitalized identifier to render a dynamic element type.
  const Tag = as;
  const classes = cn(
    FONT_CONFIG[font],
    COLOR_CONFIG[color],
    nowrap ? "whitespace-nowrap" : undefined
  );
  return (
    <Tag {...rest} className={classes}>
      {/* Falsy children (undefined or "") render nothing. */}
      {children ? resolveStr(children) : null}
    </Tag>
  );
}
export { Text, type TextProps, type TextFont, type TextColor };

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
// Icon component: "file-broadcast" (18x18 viewBox). Stroke follows
// `currentColor`, so the icon inherits the surrounding text color;
// `size` sets both width and height.
const SvgFileBroadcast = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 18 18"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M6.1875 2.25003H2.625C1.808 2.25003 1.125 2.93303 1.125 3.75003L1.125 14.25C1.125 15.067 1.808 15.75 2.625 15.75L9.37125 15.75C10.1883 15.75 10.8713 15.067 10.8713 14.25L10.8713 6.94128M6.1875 2.25003L10.8713 6.94128M6.1875 2.25003V6.94128H10.8713M10.3069 2.25L13.216 5.15914C13.6379 5.5811 13.875 6.15339 13.875 6.75013V13.875C13.875 14.5212 13.737 15.2081 13.4392 15.7538M16.4391 15.7538C16.737 15.2081 16.875 14.5213 16.875 13.8751L16.875 7.02481C16.875 5.53418 16.2833 4.10451 15.23 3.04982L14.4301 2.25003"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgFileBroadcast;

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
// Icon component: "hook-nodes" (16x16 viewBox). Stroke follows
// `currentColor`, so the icon inherits the surrounding text color;
// `size` sets both width and height.
const SvgHookNodes = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M10.0002 4C10.0002 3.99708 10.0002 3.99415 10.0001 3.99123C9.99542 2.8907 9.10181 2 8.00016 2C6.89559 2 6.00016 2.89543 6.00016 4C6.00016 4.73701 6.39882 5.38092 6.99226 5.72784L4.67276 9.70412M11.6589 13.7278C11.9549 13.9009 12.2993 14 12.6668 14C13.7714 14 14.6668 13.1046 14.6668 12C14.6668 10.8954 13.7714 10 12.6668 10C12.2993 10 11.9549 10.0991 11.6589 10.2722L9.33943 6.29588M2.33316 10.2678C1.73555 10.6136 1.3335 11.2599 1.3335 12C1.3335 13.1046 2.22893 14 3.3335 14C4.43807 14 5.3335 13.1046 5.3335 12H10.0002"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgHookNodes;

View File

@@ -69,8 +69,9 @@ export { default as SvgExternalLink } from "@opal/icons/external-link";
export { default as SvgEye } from "@opal/icons/eye";
export { default as SvgEyeClosed } from "@opal/icons/eye-closed";
export { default as SvgEyeOff } from "@opal/icons/eye-off";
export { default as SvgFiles } from "@opal/icons/files";
export { default as SvgFileBraces } from "@opal/icons/file-braces";
export { default as SvgFileBroadcast } from "@opal/icons/file-broadcast";
export { default as SvgFiles } from "@opal/icons/files";
export { default as SvgFileChartPie } from "@opal/icons/file-chart-pie";
export { default as SvgFileSmall } from "@opal/icons/file-small";
export { default as SvgFileText } from "@opal/icons/file-text";
@@ -90,6 +91,7 @@ export { default as SvgHashSmall } from "@opal/icons/hash-small";
export { default as SvgHash } from "@opal/icons/hash";
export { default as SvgHeadsetMic } from "@opal/icons/headset-mic";
export { default as SvgHistory } from "@opal/icons/history";
export { default as SvgHookNodes } from "@opal/icons/hook-nodes";
export { default as SvgHourglass } from "@opal/icons/hourglass";
export { default as SvgImage } from "@opal/icons/image";
export { default as SvgImageSmall } from "@opal/icons/image-small";

View File

@@ -55,6 +55,9 @@ interface ContentMdProps {
/** When `true`, renders "(Optional)" beside the title. */
optional?: boolean;
/** Custom muted suffix rendered beside the title. */
titleSuffix?: string;
/** Auxiliary status icon rendered beside the title. */
auxIcon?: ContentMdAuxIcon;
@@ -138,6 +141,7 @@ function ContentMd({
editable,
onTitleChange,
optional,
titleSuffix,
auxIcon,
tag,
sizePreset = "main-ui",
@@ -234,12 +238,12 @@ function ContentMd({
</span>
)}
{optional && (
{(optional || titleSuffix) && (
<span
className={cn(config.optionalFont, "text-text-03 shrink-0")}
style={{ height: config.lineHeight }}
>
(Optional)
{titleSuffix ?? "(Optional)"}
</span>
)}

View File

@@ -96,6 +96,8 @@ type MdContentProps = ContentBaseProps & {
variant?: "section";
/** When `true`, renders "(Optional)" beside the title in the muted font variant. */
optional?: boolean;
/** Custom muted suffix rendered beside the title. */
titleSuffix?: string;
/** Auxiliary status icon rendered beside the title. */
auxIcon?: "info-gray" | "info-blue" | "warning" | "error";
/** Tag rendered beside the title. */

View File

@@ -86,6 +86,26 @@ export interface IconProps extends SVGProps<SVGSVGElement> {
/** Strips `className` and `style` from a props type to enforce design-system styling. */
export type WithoutStyles<T> = Omit<T, "className" | "style">;
// ---------------------------------------------------------------------------
// Rich Strings
// ---------------------------------------------------------------------------
/**
* A branded string wrapper that signals inline markdown should be parsed.
*
* Created via the `markdown()` function. Components that accept `string | RichStr`
* will parse the inner `raw` string as inline markdown when a `RichStr` is passed,
* and render plain text when a regular `string` is passed.
*
* This avoids "API coloring" — components don't need a `markdown` boolean prop,
* and intermediate wrappers don't need to thread it through. The decision to
* use markdown lives at the call site via `markdown("*bold* text")`.
*/
export interface RichStr {
readonly __brand: "RichStr";
readonly raw: string;
}
/**
* HTML button `type` attribute values.
*

Some files were not shown because too many files have changed in this diff Show More