mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-17 07:26:45 +00:00
Compare commits
42 Commits
v3.2.0-clo
...
release/v3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1215ef4576 | ||
|
|
63d6f01895 | ||
|
|
8fc2b3c3de | ||
|
|
f5c48887f1 | ||
|
|
fe363bb62b | ||
|
|
9862b0ef59 | ||
|
|
8a7aeb2c59 | ||
|
|
648dcd1e47 | ||
|
|
f73796928c | ||
|
|
91101e8f2c | ||
|
|
44bb3ded44 | ||
|
|
493e3f23b8 | ||
|
|
031c1118bd | ||
|
|
b8b7702f28 | ||
|
|
ebb67aede9 | ||
|
|
340cd520eb | ||
|
|
b626ad232c | ||
|
|
f1ee9c12c0 | ||
|
|
378cbedaa1 | ||
|
|
f87e03b194 | ||
|
|
873636a095 | ||
|
|
efb194e067 | ||
|
|
3f7dfa7813 | ||
|
|
5f08af3678 | ||
|
|
1243af4f86 | ||
|
|
91e84b8278 | ||
|
|
1d6baf10db | ||
|
|
8d26357197 | ||
|
|
cd43345415 | ||
|
|
f99cf2f1b0 | ||
|
|
7332adb1e6 | ||
|
|
0ab1b76765 | ||
|
|
40cd0a78a3 | ||
|
|
28d8c5de46 | ||
|
|
004092767f | ||
|
|
eb4689a669 | ||
|
|
47dd8973c1 | ||
|
|
a1403ef78c | ||
|
|
f96b9d6804 | ||
|
|
711651276c | ||
|
|
3731110cf9 | ||
|
|
8fb7a8718e |
@@ -13,6 +13,7 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -107,12 +108,13 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
For self-hosted: counts all active users.
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -129,6 +131,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,6 +24,7 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -65,12 +68,25 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -209,12 +225,37 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -94,6 +95,17 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -16,6 +16,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -36,6 +42,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -50,6 +57,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -90,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,12 @@ from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -59,6 +65,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -73,6 +80,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -212,6 +244,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("primary")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -523,6 +532,12 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -541,6 +556,11 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -717,5 +737,6 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -34,6 +34,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
@@ -467,6 +468,15 @@ def docfetching_proxy_task(
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
cc_pair = index_attempt.connector_credential_pair
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=result.connector_source,
|
||||
cc_pair_id=cc_pair_id,
|
||||
connector_name=cc_pair.connector.name or f"cc_pair_{cc_pair_id}",
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
|
||||
@@ -105,6 +105,9 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
@@ -400,7 +403,6 @@ def check_indexing_completion(
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
|
||||
)
|
||||
@@ -521,13 +523,25 @@ def check_indexing_completion(
|
||||
|
||||
# Update CC pair status if successful
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session, attempt.connector_credential_pair_id
|
||||
db_session,
|
||||
attempt.connector_credential_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise RuntimeError(
|
||||
f"CC pair {attempt.connector_credential_pair_id} not found in database"
|
||||
)
|
||||
|
||||
source = cc_pair.connector.source.value
|
||||
connector_name = cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
status=attempt.status.value,
|
||||
)
|
||||
|
||||
if attempt.status.is_successful():
|
||||
# NOTE: we define the last successful index time as the time the last successful
|
||||
# attempt finished. This is distinct from the poll_range_end of the last successful
|
||||
@@ -548,10 +562,26 @@ def check_indexing_completion(
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
on_connector_indexing_success(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
docs_indexed=attempt.new_docs_indexed or 0,
|
||||
success_timestamp=attempt.time_updated.timestamp(),
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
db_session.commit()
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
in_error=False,
|
||||
)
|
||||
|
||||
if attempt.status == IndexingStatus.SUCCESS:
|
||||
logger.info(
|
||||
@@ -848,6 +878,16 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
error_connector_name = (
|
||||
cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
|
||||
)
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=cc_pair.connector.source.value,
|
||||
cc_pair_id=cc_pair_id,
|
||||
connector_name=error_connector_name,
|
||||
in_error=True,
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -826,6 +826,12 @@ def translate_history_to_llm_format(
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
content_parts.append(
|
||||
TextContentPart(
|
||||
type="text",
|
||||
text=f"[attached image — file_id: {img_file.file_id}]",
|
||||
)
|
||||
)
|
||||
image_part = ImageContentPart(
|
||||
type="image_url",
|
||||
image_url=ImageUrlDetail(
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1006,93 +1005,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -840,6 +840,29 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Maximum embedded images allowed in a single file. PDFs (and other formats)
|
||||
# with thousands of embedded images can OOM the user-file-processing worker
|
||||
# because every image is decoded with PIL and then sent to the vision LLM.
|
||||
# Enforced both at upload time (rejects the file) and during extraction
|
||||
# (defense-in-depth: caps the number of images materialized).
|
||||
#
|
||||
# Clamped to >= 0; a negative env value would turn upload validation into
|
||||
# always-fail and extraction into always-stop, which is never desired. 0
|
||||
# disables image extraction entirely, which is a valid (if aggressive) setting.
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
|
||||
)
|
||||
|
||||
# Maximum embedded images allowed across all files in a single upload batch.
|
||||
# Protects against the scenario where a user uploads many files that each
|
||||
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
|
||||
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
|
||||
# worker under concurrency or run up surprise latency/cost. Also clamped
|
||||
# to >= 0.
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urljoin
|
||||
@@ -10,7 +11,6 @@ from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from dateutil.parser import parse
|
||||
from dateutil.parser import ParserError
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -56,18 +56,16 @@ def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
if fixed not in candidates:
|
||||
candidates.append(fixed)
|
||||
|
||||
last_exception: Exception | None = None
|
||||
for candidate in candidates:
|
||||
try:
|
||||
dt = parse(candidate)
|
||||
return datetime_to_utc(dt)
|
||||
except (ValueError, ParserError) as exc:
|
||||
last_exception = exc
|
||||
# dateutil is the primary; the stdlib RFC 2822 parser is a fallback for
|
||||
# inputs dateutil rejects (e.g. headers concatenated without a CRLF —
|
||||
# TZ may be dropped, datetime_to_utc then assumes UTC).
|
||||
for parser in (parse, parsedate_to_datetime):
|
||||
for candidate in candidates:
|
||||
try:
|
||||
return datetime_to_utc(parser(candidate))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
continue
|
||||
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
|
||||
# Fallback in case parsing failed without raising (should not happen)
|
||||
raise ValueError(f"Unable to parse datetime string: {datetime_str}")
|
||||
|
||||
|
||||
|
||||
@@ -253,7 +253,17 @@ def thread_to_document(
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
try:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
except (ValueError, OverflowError) as e:
|
||||
# Old mailboxes contain RFC-violating Date headers. Drop the
|
||||
# timestamp instead of aborting the indexing run.
|
||||
logger.warning(
|
||||
"Skipping unparseable Gmail Date header on thread %s: %r (%s)",
|
||||
full_thread.get("id"),
|
||||
updated_at,
|
||||
e,
|
||||
)
|
||||
|
||||
id = full_thread.get("id")
|
||||
if not id:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -162,12 +178,13 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
return GoogleAppCredentials(**creds)
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -234,12 +257,14 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -60,8 +60,10 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -255,15 +257,13 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,12 +277,25 @@ def bulk_fetch_issues(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -6,6 +7,14 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DirectThreadFetch:
    """Immutable request to load one Slack thread by channel and timestamp."""

    # Slack channel identifier the thread lives in.
    channel_id: str
    # Timestamp string identifying the thread within that channel.
    thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
    thread_fetch: DirectThreadFetch,
    access_token: str,
    channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
    """Load the thread named by a Slack URL and wrap it as one search result.

    Calls ``conversations.replies`` for the channel/timestamp pair carried by
    ``thread_fetch`` and turns the returned messages into a single
    ``SlackMessage`` whose text comes from ``_build_thread_text``.

    Args:
        thread_fetch: Channel id and thread timestamp to fetch.
        access_token: Slack token used to build the ``WebClient``.
        channel_metadata_dict: Optional channel metadata keyed by channel id;
            when it contains the channel, its ``name`` is used instead of
            calling ``conversations.info``.

    Returns:
        A ``SlackQueryResult`` holding one message, or an empty result when
        the Slack call fails or returns no messages.
    """
    channel_id, thread_ts = thread_fetch.channel_id, thread_fetch.thread_ts
    slack_client = WebClient(token=access_token)

    try:
        response = slack_client.conversations_replies(channel=channel_id, ts=thread_ts)
        response.validate()
        messages: list[dict[str, Any]] = response.get("messages", [])
    except SlackApiError as e:
        logger.warning(
            f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    if not messages:
        logger.warning(
            f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    # Thread text is assembled from every returned message.
    thread_text = _build_thread_text(messages, access_token, None, slack_client)

    # Resolve the channel name: supplied metadata first, otherwise ask the API.
    # A failed lookup is tolerated and leaves the "unknown" placeholder.
    channel_name = "unknown"
    if not channel_metadata_dict or channel_id not in channel_metadata_dict:
        try:
            ch_response = slack_client.conversations_info(channel=channel_id)
            ch_response.validate()
            channel_info: dict[str, Any] = ch_response.get("channel", {})
            channel_name = channel_info.get("name", "unknown")
        except SlackApiError:
            pass
    else:
        channel_name = channel_metadata_dict[channel_id].get("name", "unknown")

    # The first returned message supplies the identity of the result.
    parent_msg = messages[0]
    message_ts = parent_msg.get("ts", thread_ts)
    username = parent_msg.get("user", "unknown_user")
    parent_text = parent_msg.get("text", "")

    # Short single-line preview of the parent message for the identifier.
    if len(parent_text) > 50:
        preview = parent_text[:50].rstrip() + "..."
    else:
        preview = parent_text
    snippet = preview.replace("\n", " ")

    # Recency bias decays with document age but never drops below 0.75.
    doc_time = datetime.fromtimestamp(float(message_ts))
    age_seconds = (datetime.now() - doc_time).total_seconds()
    doc_age_years = age_seconds / (365 * 24 * 60 * 60)
    recency_bias = max(1 / (1 + DOC_TIME_DECAY * doc_age_years), 0.75)

    permalink = (
        f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
    )

    slack_message = SlackMessage(
        document_id=f"{channel_id}_{message_ts}",
        channel_id=channel_id,
        message_id=message_ts,
        thread_id=None,  # Prevent double-enrichment in thread context fetch
        link=permalink,
        metadata={
            "channel": channel_name,
            "time": doc_time.isoformat(),
        },
        timestamp=doc_time,
        recency_bias=recency_bias,
        semantic_identifier=f"{username} in #{channel_name}: {snippet}",
        text=thread_text,
        highlighted_texts=set(),
        slack_score=100000.0,  # High priority — user explicitly asked for this thread
    )

    logger.info(
        f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
    )

    return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -432,7 +519,6 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -662,7 +748,6 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -695,62 +780,37 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
for msg in replies:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -976,7 +1036,16 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -993,8 +1062,16 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1010,7 +1087,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
# Matches Slack message permalinks; group 1 is the channel id and group 2 the
# 16-digit timestamp that follows the "p" prefix.
SLACK_URL_PATTERN = re.compile(
    r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)


def extract_slack_message_urls(
    query_text: str,
) -> list[tuple[str, str]]:
    """Extract Slack message URLs from query text.

    Parses URLs like:
        https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769

    The 16-digit timestamp is converted to Slack ts format by inserting a dot
    after the first ten digits (p1775491616524769 -> 1775491616.524769).

    Args:
        query_text: Free-form text that may contain Slack message URLs.

    Returns:
        ``(channel_id, thread_ts)`` tuples in order of first appearance. A URL
        that occurs more than once is reported only once, so callers do not
        fetch the same thread repeatedly.
    """
    found = (
        (match.group(1), f"{match.group(2)[:10]}.{match.group(2)[10:]}")
        for match in SLACK_URL_PATTERN.finditer(query_text)
    )
    # dict.fromkeys drops repeats while preserving first-seen order.
    return list(dict.fromkeys(found))
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -668,6 +695,15 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -684,7 +720,9 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -702,7 +740,8 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
return [
|
||||
search_queries = [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -750,31 +750,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
    db_session: Session,
) -> list:  # Returns list of Row tuples
    """Fetch per-connector health columns used for Prometheus metrics.

    Joins each connector-credential pair to its connector and returns one row
    per pair, shaped as (cc_pair_id, status, in_repeated_error_state,
    last_successful_index_time, name, source).
    """
    selected_columns = (
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.status,
        ConnectorCredentialPair.in_repeated_error_state,
        ConnectorCredentialPair.last_successful_index_time,
        ConnectorCredentialPair.name,
        Connector.source,
    )
    health_query = db_session.query(*selected_columns).join(
        Connector,
        ConnectorCredentialPair.connector_id == Connector.id,
    )
    return health_query.all()
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
    """Close a session, catching connection-closed errors during cleanup.

    Long-running operations (e.g. multi-model LLM loops) can hold a session
    open for minutes. If the underlying connection is dropped by cloud
    infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
    timeouts, etc.), the implicit rollback in Session.close() raises
    OperationalError or InterfaceError. Since the work is already complete,
    we log and move on — SQLAlchemy internally invalidates the connection
    for pool recycling.

    Args:
        session: The session to close.

    Any ``DBAPIError`` raised by ``session.close()`` is logged as a warning
    and suppressed; other exceptions propagate.
    """
    try:
        session.close()
    except DBAPIError:
        # exc_info attaches the driver error and traceback so the cause of
        # the dropped connection can be diagnosed from the logs.
        logger.warning(
            "DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool.",
            exc_info=True,
        )
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -2,8 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -30,17 +28,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
|
||||
# from onyx.db.engine import async_query_for_dms
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -977,106 +964,3 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
    """One aggregated row produced by get_active_index_attempts_for_metrics."""

    # Indexing status shared by the attempts counted in this row.
    status: IndexingStatus
    # Source of the connector the attempts belong to.
    source: "DocumentSource"
    # Id of the connector-credential pair.
    cc_pair_id: int
    # Name of the connector-credential pair; may be None.
    cc_pair_name: str | None
    # Number of attempts in this (status, source, cc_pair) group.
    attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
    db_session: Session,
) -> list[ActiveIndexAttemptMetric]:
    """Count in-flight index attempts per status, source, and connector.

    Attempts whose status reports ``is_terminal()`` are excluded. Each
    returned item is (status, source, cc_pair_id, cc_pair_name, attempt_count).
    """
    from onyx.db.models import Connector

    finished_statuses = [
        candidate for candidate in IndexingStatus if candidate.is_terminal()
    ]

    # The same columns are both selected and grouped on.
    group_columns = (
        IndexAttempt.status,
        Connector.source,
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.name,
    )
    grouped_query = (
        db_session.query(*group_columns, func.count())
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .filter(IndexAttempt.status.notin_(finished_statuses))
        .group_by(*group_columns)
    )
    return [ActiveIndexAttemptMetric._make(row) for row in grouped_query.all()]
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Map each cc_pair_id to its number of FAILED index attempts.

    Only attempts created at or after ``since`` are counted. When ``since``
    is omitted the window starts 90 days before now (UTC), which avoids
    unbounded historical aggregation.
    """
    window_start = (
        since if since is not None else datetime.now(timezone.utc) - timedelta(days=90)
    )

    failed_query = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            func.count(),
        )
        .filter(
            IndexAttempt.status == IndexingStatus.FAILED,
            IndexAttempt.time_created >= window_start,
        )
        .group_by(IndexAttempt.connector_credential_pair_id)
    )

    counts_by_pair: dict[int, int] = {}
    for cc_id, count in failed_query.all():
        counts_by_pair[cc_id] = count
    return counts_by_pair
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Sum ``new_docs_indexed`` over SUCCESS attempts per connector-credential pair.

    Only attempts whose status is SUCCESS contribute, so partial results
    from failed attempts do not inflate the totals. A NULL
    ``new_docs_indexed`` is treated as 0.

    Args:
        db_session: Session used to run the aggregate query.
        since: Lower bound (inclusive) on ``IndexAttempt.time_created``.
            When omitted, the window defaults to the last 90 days.

    Returns:
        ``{cc_pair_id: total_new_docs_indexed}`` with integer totals. Pairs
        with no successful attempt inside the window do not appear.
    """
    cutoff = (
        since
        if since is not None
        else datetime.now(timezone.utc) - timedelta(days=90)
    )

    docs_sum = func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0))
    totals_by_pair = (
        db_session.query(IndexAttempt.connector_credential_pair_id, docs_sum)
        .filter(IndexAttempt.status == IndexingStatus.SUCCESS)
        .filter(IndexAttempt.time_created >= cutoff)
        .group_by(IndexAttempt.connector_credential_pair_id)
        .all()
    )

    # The aggregate may come back as None/Decimal depending on the backend;
    # normalise every total to a plain int.
    return {pair_id: int(docs_total or 0) for pair_id, docs_total in totals_by_pair}
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
    """Look up the id of the research-agent tool using a short-lived session.

    A session is opened only for the duration of the lookup, and the id is
    read while that session is still open.
    """
    with get_session_with_current_tenant() as db_session:
        research_tool = get_tool_by_name(
            tool_name=RESEARCH_AGENT_TOOL_NAME,
            db_session=db_session,
        )
        return research_tool.id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -23,6 +23,7 @@ import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -191,6 +192,56 @@ def read_text_file(
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
    """Count the embedded images in a PDF, stopping as soon as ``cap`` is exceeded.

    The purpose is to reject PDFs whose image count would overwhelm the
    user-file-processing worker during indexing. Once the running count
    goes above ``cap`` that count is returned immediately as a "too many"
    sentinel, so callers never walk thousands of image objects just to
    report a number.

    Returns 0 when the PDF cannot be parsed. Encrypted PDFs are first tried
    with an empty password (this covers owner-password-only documents); if
    that does not decrypt them, 0 is returned because the contents cannot be
    inspected — callers should run their password-protected check first.

    The stream position observed on entry is restored before returning
    whenever it could be determined.
    """
    from pypdf import PdfReader

    # Some streams cannot report a position; in that case we neither rewind
    # before parsing nor try to restore afterwards.
    try:
        origin = file.tell()
    except Exception:
        origin = None

    try:
        if origin is not None:
            file.seek(0)
        pdf = PdfReader(file)

        if pdf.is_encrypted:
            # Empty password first (owner-password-only PDFs); give up if
            # that fails or raises.
            try:
                unlocked = pdf.decrypt("") != 0
            except Exception:
                unlocked = False
            if not unlocked:
                return 0

        seen = 0
        for page in pdf.pages:
            for _image in page.images:
                seen += 1
                if seen > cap:
                    return seen
        return seen
    except Exception:
        logger.warning("Failed to count embedded images in PDF", exc_info=True)
        return 0
    finally:
        if origin is not None:
            # Best-effort restore; a failure here must not mask the result.
            try:
                file.seek(origin)
            except Exception:
                pass
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
"""
|
||||
Extract text from a PDF. For embedded images, a more complex approach is needed.
|
||||
@@ -254,8 +305,27 @@ def read_pdf_file(
|
||||
)
|
||||
|
||||
if extract_images:
|
||||
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
images_processed = 0
|
||||
cap_reached = False
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
if cap_reached:
|
||||
break
|
||||
for image_file_object in page.images:
|
||||
if images_processed >= image_cap:
|
||||
# Defense-in-depth backstop. Upload-time validation
|
||||
# should have rejected files exceeding the cap, but
|
||||
# we also break here so a single oversized file can
|
||||
# never pin a worker.
|
||||
logger.warning(
|
||||
"PDF embedded image cap reached (%d). "
|
||||
"Skipping remaining images on page %d and beyond.",
|
||||
image_cap,
|
||||
page_num + 1,
|
||||
)
|
||||
cap_reached = True
|
||||
break
|
||||
|
||||
image = Image.open(io.BytesIO(image_file_object.data))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
@@ -268,6 +338,7 @@ def read_pdf_file(
|
||||
image_callback(img_bytes, image_name)
|
||||
else:
|
||||
extracted_images.append((img_bytes, image_name))
|
||||
images_processed += 1
|
||||
|
||||
return text, metadata, extracted_images
|
||||
|
||||
|
||||
@@ -1516,6 +1516,10 @@
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-7": {
|
||||
"display_name": "Claude Opus 4.7",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -46,6 +46,15 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
|
||||
ReasoningEffort.HIGH: 4096,
|
||||
}
|
||||
|
||||
# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
|
||||
# output_config.effort instead of thinking.type.enabled + budget_tokens.
|
||||
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
|
||||
ReasoningEffort.AUTO: "medium",
|
||||
ReasoningEffort.LOW: "low",
|
||||
ReasoningEffort.MEDIUM: "medium",
|
||||
ReasoningEffort.HIGH: "high",
|
||||
}
|
||||
|
||||
|
||||
# Content part structures for multimodal messages
|
||||
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import Usage
|
||||
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
|
||||
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.request_context import get_llm_mock_response
|
||||
@@ -67,8 +68,13 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
)
|
||||
|
||||
# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
|
||||
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
|
||||
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
"""
|
||||
@@ -230,6 +236,14 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
    """Return True if ``model_name`` names an adaptive-thinking Anthropic model.

    The match is a case-insensitive substring test against each entry of
    ``_ANTHROPIC_ADAPTIVE_THINKING_MODELS``, so provider prefixes or version
    suffixes around the model id do not prevent a match.
    """
    lowered = model_name.lower()
    for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS:
        if adaptive_model in lowered:
            return True
    return False
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -509,10 +523,6 @@ class LitellmLLM(LLM):
|
||||
}
|
||||
|
||||
elif is_claude_model:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
@@ -520,24 +530,35 @@ class LitellmLLM(LLM):
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
has_tool_call_history = _prompt_contains_tool_call_history(prompt)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
if _anthropic_uses_adaptive_thinking(self.config.model_name):
|
||||
# Newer Anthropic models (Claude Opus 4.7+) reject
|
||||
# thinking.type.enabled — they require the adaptive
|
||||
# thinking config with output_config.effort.
|
||||
if not has_tool_call_history:
|
||||
optional_kwargs["thinking"] = {"type": "adaptive"}
|
||||
optional_kwargs["output_config"] = {
|
||||
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
|
||||
reasoning_effort
|
||||
],
|
||||
}
|
||||
else:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
if budget_tokens is not None and not has_tool_call_history:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
|
||||
# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
|
||||
optional_kwargs.pop("reasoning_effort", None)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"version": "1.2",
|
||||
"updated_at": "2026-04-16T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
@@ -10,8 +10,12 @@
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
"default_model": "claude-opus-4-6",
|
||||
"default_model": "claude-opus-4-7",
|
||||
"additional_visible_models": [
|
||||
{
|
||||
"name": "claude-opus-4-7",
|
||||
"display_name": "Claude Opus 4.7"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
|
||||
@@ -65,8 +65,9 @@ IMPORTANT: each call to this tool is independent. Variables from previous calls
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
To edit, restyle, or vary an existing image, pass its file_id in `reference_image_file_ids`. \
|
||||
File IDs come from `[attached image — file_id: <id>]` tags on user-attached images or from prior `generate_image` tool results — never invent one. \
|
||||
Leave `reference_image_file_ids` unset for a fresh generation.
|
||||
""".lstrip()
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.background.celery.versioned_apps.client import app as celery_app
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -51,6 +53,9 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
|
||||
@@ -128,6 +133,49 @@ class DeleteFileResponse(BaseModel):
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
|
||||
"""True if either the filename or the content-type indicates a PDF.
|
||||
|
||||
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
|
||||
``Content-Type: application/octet-stream``), so we also fall back to
|
||||
extension-based detection via ``mimetypes.guess_type`` on the filename.
|
||||
"""
|
||||
if content_type == "application/pdf":
|
||||
return True
|
||||
guessed, _ = mimetypes.guess_type(filename)
|
||||
return guessed == "application/pdf"
|
||||
|
||||
|
||||
def _check_pdf_image_caps(
    filename: str, content: bytes, content_type: str | None, batch_total: int
) -> int:
    """Apply the per-file and per-batch embedded-image limits to one upload.

    Args:
        filename: Name of the uploaded file (used for PDF detection and in
            the error message).
        content: Raw bytes of the uploaded file.
        content_type: Client-supplied content type, if any.
        batch_total: Embedded images already accepted earlier in this batch.

    Returns:
        The embedded-image count for this file, or 0 when the file is not a
        PDF, so the caller can add it to its running batch total.

    Raises:
        OnyxError: With ``OnyxErrorCode.INVALID_INPUT`` when the file alone
            exceeds the per-file cap, or when adding it would push the batch
            over the per-batch cap.
    """
    if not _looks_like_pdf(filename, content_type):
        return 0

    per_file_limit = MAX_EMBEDDED_IMAGES_PER_FILE
    per_batch_limit = MAX_EMBEDDED_IMAGES_PER_UPLOAD

    # Count up to the larger of the two limits so a single pass yields a
    # number that is meaningful for both comparisons below.
    scan_limit = max(per_file_limit, per_batch_limit)
    image_count = count_pdf_embedded_images(BytesIO(content), scan_limit)

    if image_count > per_file_limit:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            f"PDF '{filename}' contains too many embedded images "
            f"(more than {per_file_limit}). Try splitting the document into smaller files.",
        )

    projected_batch_total = batch_total + image_count
    if projected_batch_total > per_batch_limit:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            f"Upload would exceed the {per_batch_limit}-image limit across all "
            f"files in this batch. Try uploading fewer image-heavy files at once.",
        )

    return image_count
|
||||
|
||||
|
||||
def _sanitize_path(path: str) -> str:
|
||||
"""Sanitize a file path, removing traversal attempts and normalizing.
|
||||
|
||||
@@ -356,6 +404,7 @@ async def upload_files(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Sanitize the base path
|
||||
@@ -375,6 +424,14 @@ async def upload_files(
|
||||
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
|
||||
)
|
||||
|
||||
# Reject PDFs with an unreasonable per-file or per-batch image count
|
||||
batch_image_total += _check_pdf_image_caps(
|
||||
filename=file.filename or "unnamed",
|
||||
content=content,
|
||||
content_type=file.content_type,
|
||||
batch_total=batch_image_total,
|
||||
)
|
||||
|
||||
# Validate cumulative storage (existing + this upload batch)
|
||||
total_size += file_size
|
||||
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
|
||||
@@ -473,6 +530,7 @@ async def upload_zip(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
|
||||
# Extract zip contents into a subfolder named after the zip file
|
||||
zip_name = api_sanitize_filename(file.filename or "upload")
|
||||
@@ -511,6 +569,36 @@ async def upload_zip(
|
||||
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
|
||||
continue
|
||||
|
||||
# Skip PDFs that would trip the per-file or per-batch image
|
||||
# cap (would OOM the user-file-processing worker). Matches
|
||||
# /upload behavior but uses skip-and-warn to stay consistent
|
||||
# with the zip path's handling of oversized files.
|
||||
zip_file_name = zip_info.filename.split("/")[-1]
|
||||
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
|
||||
if zip_content_type == "application/pdf":
|
||||
image_count = count_pdf_embedded_images(
|
||||
BytesIO(file_content),
|
||||
max(
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
),
|
||||
)
|
||||
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
|
||||
logger.warning(
|
||||
"Skipping '%s' - exceeds %d per-file embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
)
|
||||
continue
|
||||
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
|
||||
logger.warning(
|
||||
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
)
|
||||
continue
|
||||
batch_image_total += image_count
|
||||
|
||||
total_size += file_size
|
||||
|
||||
# Validate cumulative storage
|
||||
|
||||
@@ -618,6 +618,7 @@ done
|
||||
"app.kubernetes.io/managed-by": "onyx",
|
||||
"onyx.app/sandbox-id": sandbox_id,
|
||||
"onyx.app/tenant-id": tenant_id,
|
||||
"admission.datadoghq.com/enabled": "false",
|
||||
},
|
||||
),
|
||||
spec=pod_spec,
|
||||
|
||||
@@ -96,6 +96,32 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
    request_client_id: str | None,
    request_client_secret: str | None,
    existing_client: OAuthClientInformationFull,
) -> tuple[str | None, str | None]:
    """Swap masked credentials echoed by the frontend for the stored originals.

    A request value is replaced only when it is non-empty, the matching
    stored value is non-empty, and the request value equals
    ``mask_string(<stored value>)``. Anything else passes through untouched.

    Returns:
        ``(client_id, client_secret)`` after restoration.
    """
    resolved_id = request_client_id
    resolved_secret = request_client_secret

    stored_id = existing_client.client_id
    if resolved_id and stored_id and resolved_id == mask_string(stored_id):
        resolved_id = stored_id

    stored_secret = existing_client.client_secret
    if (
        resolved_secret
        and stored_secret
        and resolved_secret == mask_string(stored_secret)
    ):
        resolved_secret = stored_secret

    return resolved_id, resolved_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -392,6 +418,26 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1356,6 +1402,19 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -49,6 +52,13 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
21
backend/onyx/server/features/notifications/utils.py
Normal file
21
backend/onyx/server/features/notifications/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
    """Create the permissions-migration feature announcement for ``user``.

    Delegates to ``create_notification`` with a FEATURE_ANNOUNCEMENT payload
    that links to the permissions change documentation.
    """
    # Feature id "permissions_migration_v1" must not change after shipping —
    # it is the dedup key on (user_id, notif_type, additional_data).
    announcement_payload = {
        "feature": "permissions_migration_v1",
        "link": "https://docs.onyx.app/admins/permissions/whats_changing",
    }
    create_notification(
        user_id=user.id,
        notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
        db_session=db_session,
        title="Permissions are changing in Onyx",
        description="Roles are moving to group-based permissions. Click for details.",
        additional_data=announcement_payload,
    )
|
||||
@@ -9,7 +9,10 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -190,6 +193,11 @@ def categorize_uploaded_files(
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
|
||||
# Running total of embedded images across PDFs in this batch. Once the
|
||||
# aggregate cap is reached, subsequent PDFs in the same upload are
|
||||
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
|
||||
batch_image_total = 0
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
@@ -252,6 +260,47 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject PDFs with an unreasonable number of embedded images
|
||||
# (either per-file or accumulated across this upload batch).
|
||||
# A PDF with thousands of embedded images can OOM the
|
||||
# user-file-processing celery worker because every image is
|
||||
# decoded with PIL and then sent to the vision LLM.
|
||||
if extension == ".pdf":
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Use the larger of the two caps as the short-circuit
|
||||
# threshold so we get a useful count for both checks.
|
||||
# count_pdf_embedded_images restores the stream position.
|
||||
count = count_pdf_embedded_images(
|
||||
upload.file, max(file_cap, batch_cap)
|
||||
)
|
||||
if count > file_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"PDF contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting "
|
||||
f"the document into smaller files."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
if batch_image_total + count > batch_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"Upload would exceed the "
|
||||
f"{batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading "
|
||||
f"fewer image-heavy files at once."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
batch_image_total += count
|
||||
|
||||
text_content = extract_file_text(
|
||||
file=upload.file,
|
||||
file_name=filename,
|
||||
|
||||
@@ -111,6 +111,43 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
    api_key: str | None,
    provider_name: str | None,
    api_base: str | None,
    db_session: Session,
) -> str | None:
    """Pick the API key to send to a provider's model-listing endpoint.

    While an existing provider is being edited, the form carries the masked
    key (e.g. ``sk-a****b1c2``) rather than the real one. Given a
    *provider_name*, the stored provider is looked up and its unmasked key is
    substituted, but only when both of the following hold:

    * the request's *api_base* equals the stored one (compared after
      stripping whitespace and trailing slashes), and
    * *api_key* is exactly the masked rendering of the stored key, i.e. the
      user has not typed a new key.

    In every other case the incoming *api_key* is returned untouched.

    Args:
        api_key: Key value received in the request (possibly masked).
        provider_name: Name of the existing provider, if any.
        api_base: Base URL the request is going to be sent to.
        db_session: Session used to look up the stored provider.

    Returns:
        The stored unmasked key when the conditions above are met, otherwise
        *api_key* as given.
    """
    if not provider_name:
        return api_key

    provider = fetch_existing_llm_provider(name=provider_name, db_session=db_session)
    if not provider or not provider.api_key:
        return api_key

    def _normalise(url: str | None) -> str:
        # Trailing-slash / whitespace differences must not count as a mismatch.
        return (url or "").strip().rstrip("/")

    if _normalise(provider.api_base) != _normalise(api_base):
        # Different endpoint: never hand the stored secret to it.
        return api_key

    real_key = provider.api_key.get_value(apply_mask=False)
    if api_key and api_key == _mask_string(real_key):
        return real_key
    return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -1174,16 +1211,17 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
headers: dict[str, str] = {
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1206,8 +1244,12 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1300,13 +1342,18 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1390,8 +1437,12 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
api_key=api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1448,7 +1499,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1523,8 +1574,12 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1613,8 +1668,12 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
api_base=request.api_base, api_key=api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
|
||||
@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
"qwen2.5:7b" → "Qwen 2.5 7B"
|
||||
"mistral:latest" → "Mistral"
|
||||
"deepseek-r1:14b" → "DeepSeek R1 14B"
|
||||
"gemma4:e4b" → "Gemma 4 E4B"
|
||||
"deepseek-v3.1:671b-cloud" → "DeepSeek V3.1 671B Cloud"
|
||||
"qwen3-vl:235b-instruct-cloud" → "Qwen 3-vl 235B Instruct Cloud"
|
||||
"""
|
||||
# Split into base name and tag
|
||||
if ":" in model_name:
|
||||
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
# Default: Title case with dashes converted to spaces
|
||||
display_name = base.replace("-", " ").title()
|
||||
|
||||
# Process tag to extract size info (skip "latest")
|
||||
# Process tag (skip "latest")
|
||||
if tag and tag.lower() != "latest":
|
||||
# Extract size like "7b", "70b", "14b"
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
|
||||
# Check for size prefix like "7b", "70b", optionally followed by modifiers
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
|
||||
if size_match:
|
||||
size = size_match.group(1).upper()
|
||||
display_name = f"{display_name} {size}"
|
||||
remainder = size_match.group(2)
|
||||
if remainder:
|
||||
# Format modifiers like "-cloud", "-instruct-cloud"
|
||||
modifiers = " ".join(
|
||||
p.title() for p in remainder.strip("-").split("-") if p
|
||||
)
|
||||
display_name = f"{display_name} {size} {modifiers}"
|
||||
else:
|
||||
display_name = f"{display_name} {size}"
|
||||
else:
|
||||
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
|
||||
display_name = f"{display_name} {tag.upper()}"
|
||||
|
||||
return display_name
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
@@ -113,28 +114,47 @@ async def transcribe_audio(
|
||||
) from exc
|
||||
|
||||
|
||||
def _extract_provider_error(exc: Exception) -> str:
|
||||
"""Extract a human-readable message from a provider exception.
|
||||
|
||||
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
|
||||
This tries to parse a readable ``message`` field out of common JSON
|
||||
error shapes; falls back to ``str(exc)`` if nothing better is found.
|
||||
"""
|
||||
raw = str(exc)
|
||||
try:
|
||||
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
|
||||
json_start = raw.find("{")
|
||||
if json_start == -1:
|
||||
return raw
|
||||
parsed = json.loads(raw[json_start:])
|
||||
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
|
||||
detail = parsed.get("detail", parsed)
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message") or detail.get("error") or raw
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
    """Request body for the ``/synthesize`` text-to-speech endpoint."""

    # Text to synthesize; required and must contain at least one character.
    text: str = Field(..., min_length=1)
    # Optional voice ID to use.
    voice: str | None = None
    # Optional playback speed; when provided it must lie within 0.5-2.0.
    speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
body: SynthesizeRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
"""Synthesize text to speech using the default TTS provider."""
|
||||
text = body.text
|
||||
voice = body.voice
|
||||
speed = body.speed
|
||||
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
@@ -177,31 +197,36 @@ async def synthesize_speech(
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
# Pull the first chunk before returning the StreamingResponse. If the
|
||||
# provider rejects the request (e.g. text too long), the error surfaces
|
||||
# as a proper HTTP error instead of a broken audio stream.
|
||||
stream_iter = provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
)
|
||||
try:
|
||||
first_chunk = await stream_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
|
||||
except Exception as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
|
||||
) from exc
|
||||
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
yield first_chunk
|
||||
chunk_count = 1
|
||||
async for chunk in stream_iter:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
active task gauge, task duration histograms, queue wait time histograms,
|
||||
and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
@@ -71,6 +72,32 @@ TASK_REJECTED = Counter(
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
# Histogram of how long a task sat in the queue before it started running.
# Bucket upper bounds (seconds) run from 0.1 s up to 864000 s (10 days).
TASK_QUEUE_WAIT = Histogram(
    "onyx_celery_task_queue_wait_seconds",
    "Time a Celery task spent waiting in the queue before execution started",
    labelnames=["task_name", "queue"],
    buckets=[
        # sub-second up to one minute
        0.1, 0.5, 1, 5, 30, 60,
        # 5 minutes up to one hour
        300, 600, 1800, 3600,
        # 2 hours up to one day
        7200, 14400, 28800, 43200, 86400,
        # 2, 5 and 10 days
        172800, 432000, 864000,
    ],
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
@@ -133,6 +160,13 @@ def on_celery_task_prerun(
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
|
||||
headers = getattr(task.request, "headers", None) or {}
|
||||
enqueued_at = headers.get("enqueued_at")
|
||||
if isinstance(enqueued_at, (int, float)):
|
||||
TASK_QUEUE_WAIT.labels(**labels).observe(
|
||||
max(0.0, time.time() - enqueued_at)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
123
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
123
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Prometheus metrics for connector health and index attempts.
|
||||
|
||||
Emitted by docfetching and docprocessing workers when connector or
|
||||
index attempt state changes. All functions silently catch exceptions
|
||||
to avoid disrupting the caller's business logic.
|
||||
|
||||
Gauge metrics (error state, last success timestamp) are per-process.
|
||||
With multiple worker pods, use max() aggregation in PromQL to get the
|
||||
correct value across instances, e.g.:
|
||||
max by (cc_pair_id, connector_name) (onyx_connector_in_error_state)
|
||||
|
||||
Unlike the per-task counters in indexing_task_metrics.py, these metrics
|
||||
include connector_name because their cardinality is bounded by the number
|
||||
of connectors (one series per connector), not by the number of task
|
||||
executions.
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Label names shared by every per-connector metric defined below.
_CONNECTOR_LABELS = ["tenant_id", "source", "cc_pair_id", "connector_name"]

# --- Index attempt lifecycle ---

# One increment per status transition; carries an extra "status" label on
# top of the shared connector labels.
INDEX_ATTEMPT_STATUS = Counter(
    "onyx_index_attempt_transitions_total",
    "Index attempt status transitions",
    labelnames=_CONNECTOR_LABELS + ["status"],
)

# --- Connector health ---

# 1 while the connector is in a repeated error state, 0 otherwise.
CONNECTOR_IN_ERROR_STATE = Gauge(
    "onyx_connector_in_error_state",
    "Whether the connector is in a repeated error state (1=yes, 0=no)",
    labelnames=_CONNECTOR_LABELS,
)

# Unix timestamp recorded on each successful indexing run.
CONNECTOR_LAST_SUCCESS_TIMESTAMP = Gauge(
    "onyx_connector_last_success_timestamp_seconds",
    "Unix timestamp of last successful indexing for this connector",
    labelnames=_CONNECTOR_LABELS,
)

# Running total of documents indexed by successful runs.
CONNECTOR_DOCS_INDEXED = Counter(
    "onyx_connector_docs_indexed_total",
    "Total documents indexed per connector (monotonic)",
    labelnames=_CONNECTOR_LABELS,
)

# Running total of index attempts that transitioned to "failed".
CONNECTOR_INDEXING_ERRORS = Counter(
    "onyx_connector_indexing_errors_total",
    "Total failed index attempts per connector (monotonic)",
    labelnames=_CONNECTOR_LABELS,
)
|
||||
|
||||
|
||||
def on_index_attempt_status_change(
    tenant_id: str,
    source: str,
    cc_pair_id: int,
    connector_name: str,
    status: str,
) -> None:
    """Record one index attempt status transition.

    Increments ``INDEX_ATTEMPT_STATUS`` for the connector/status pair and,
    when *status* is ``"failed"``, also increments
    ``CONNECTOR_INDEXING_ERRORS``. Any error raised while recording is
    logged at debug level and never propagates to the caller.

    Args:
        tenant_id: Tenant the connector belongs to.
        source: Connector source label value.
        cc_pair_id: Connector/credential pair id; stringified for the label.
        connector_name: Connector name label value.
        status: New status of the index attempt.
    """
    try:
        connector_labels = dict(
            tenant_id=tenant_id,
            source=source,
            cc_pair_id=str(cc_pair_id),
            connector_name=connector_name,
        )
        transitions = INDEX_ATTEMPT_STATUS.labels(status=status, **connector_labels)
        transitions.inc()
        if status == "failed":
            failures = CONNECTOR_INDEXING_ERRORS.labels(**connector_labels)
            failures.inc()
    except Exception:
        logger.debug("Failed to record index attempt status metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_error_state_change(
    tenant_id: str,
    source: str,
    cc_pair_id: int,
    connector_name: str,
    in_error: bool,
) -> None:
    """Publish a connector's repeated-error flag.

    Sets the ``CONNECTOR_IN_ERROR_STATE`` gauge for the connector to ``1.0``
    when *in_error* is truthy and ``0.0`` otherwise. Any error raised while
    recording is logged at debug level and never propagates to the caller.

    Args:
        tenant_id: Tenant the connector belongs to.
        source: Connector source label value.
        cc_pair_id: Connector/credential pair id; stringified for the label.
        connector_name: Connector name label value.
        in_error: Whether the connector is currently in the error state.
    """
    try:
        connector_labels = dict(
            tenant_id=tenant_id,
            source=source,
            cc_pair_id=str(cc_pair_id),
            connector_name=connector_name,
        )
        error_flag = CONNECTOR_IN_ERROR_STATE.labels(**connector_labels)
        if in_error:
            error_flag.set(1.0)
        else:
            error_flag.set(0.0)
    except Exception:
        logger.debug("Failed to record connector error state metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_indexing_success(
    tenant_id: str,
    source: str,
    cc_pair_id: int,
    connector_name: str,
    docs_indexed: int,
    success_timestamp: float,
) -> None:
    """Record a successful indexing run for a connector.

    Sets ``CONNECTOR_LAST_SUCCESS_TIMESTAMP`` to *success_timestamp* and,
    when *docs_indexed* is positive, adds it to ``CONNECTOR_DOCS_INDEXED``.
    Any error raised while recording is logged at debug level and never
    propagates to the caller.

    Args:
        tenant_id: Tenant the connector belongs to.
        source: Connector source label value.
        cc_pair_id: Connector/credential pair id; stringified for the label.
        connector_name: Connector name label value.
        docs_indexed: Number of documents indexed by this run.
        success_timestamp: Value stored in the last-success gauge.
    """
    try:
        connector_labels = dict(
            tenant_id=tenant_id,
            source=source,
            cc_pair_id=str(cc_pair_id),
            connector_name=connector_name,
        )
        last_success = CONNECTOR_LAST_SUCCESS_TIMESTAMP.labels(**connector_labels)
        last_success.set(success_timestamp)
        # Runs that indexed nothing leave the counter untouched.
        if docs_indexed > 0:
            docs_counter = CONNECTOR_DOCS_INDEXED.labels(**connector_labels)
            docs_counter.inc(docs_indexed)
    except Exception:
        logger.debug("Failed to record connector success metric", exc_info=True)
|
||||
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Connector-deletion-specific Prometheus metrics.
|
||||
|
||||
Tracks the deletion lifecycle:
|
||||
1. Deletions started (taskset generated)
|
||||
2. Deletions completed (success or failure)
|
||||
3. Taskset duration (from taskset generation to completion or failure).
|
||||
Note: this measures the most recent taskset execution, NOT wall-clock
|
||||
time since the user triggered the deletion. When deletion is blocked by
|
||||
indexing/pruning/permissions, the fence is cleared and a fresh taskset
|
||||
is generated on each retry, resetting this timer.
|
||||
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
|
||||
5. Fence resets (stuck deletion recovery)
|
||||
|
||||
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
|
||||
to avoid unbounded cardinality.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.deletion_metrics import (
|
||||
inc_deletion_started,
|
||||
inc_deletion_completed,
|
||||
observe_deletion_taskset_duration,
|
||||
inc_deletion_blocked,
|
||||
inc_deletion_fence_reset,
|
||||
)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Incremented by inc_deletion_started().
DELETION_STARTED = Counter(
    "onyx_deletion_started_total",
    "Connector deletions initiated (taskset generated)",
    labelnames=["tenant_id"],
)

# Incremented by inc_deletion_completed(); "outcome" distinguishes results.
DELETION_COMPLETED = Counter(
    "onyx_deletion_completed_total",
    "Connector deletions completed",
    labelnames=["tenant_id", "outcome"],
)

# Observed by observe_deletion_taskset_duration(); bucket upper bounds run
# from 10 s to 21600 s (6 hours).
DELETION_TASKSET_DURATION = Histogram(
    "onyx_deletion_taskset_duration_seconds",
    "Duration of a connector deletion taskset, from taskset generation "
    "to completion or failure. Does not include time spent blocked on "
    "indexing/pruning/permissions before the taskset was generated.",
    labelnames=["tenant_id", "outcome"],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
)

# Incremented by inc_deletion_blocked(); "blocker" names the dependency.
DELETION_BLOCKED = Counter(
    "onyx_deletion_blocked_total",
    "Times deletion was blocked by a dependency",
    labelnames=["tenant_id", "blocker"],
)

# Incremented by inc_deletion_fence_reset().
DELETION_FENCE_RESET = Counter(
    "onyx_deletion_fence_reset_total",
    "Deletion fences reset due to missing celery tasks",
    labelnames=["tenant_id"],
)
|
||||
|
||||
|
||||
def inc_deletion_started(tenant_id: str) -> None:
    """Count one started connector deletion for *tenant_id*.

    Best-effort: any error raised while recording is logged at debug level
    and swallowed.
    """
    try:
        started = DELETION_STARTED.labels(tenant_id=tenant_id)
        started.inc()
    except Exception:
        logger.debug("Failed to record deletion started", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
    """Count one completed connector deletion, labelled by *outcome*.

    Best-effort: any error raised while recording is logged at debug level
    and swallowed.
    """
    try:
        completed = DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome)
        completed.inc()
    except Exception:
        logger.debug("Failed to record deletion completed", exc_info=True)
|
||||
|
||||
|
||||
def observe_deletion_taskset_duration(
    tenant_id: str, outcome: str, duration_seconds: float
) -> None:
    """Add one taskset duration sample to ``DELETION_TASKSET_DURATION``.

    Best-effort: any error raised while recording is logged at debug level
    and swallowed.

    Args:
        tenant_id: Tenant label value.
        outcome: Outcome label value.
        duration_seconds: Duration to observe, in seconds.
    """
    try:
        duration_hist = DELETION_TASKSET_DURATION.labels(
            tenant_id=tenant_id, outcome=outcome
        )
        duration_hist.observe(duration_seconds)
    except Exception:
        logger.debug("Failed to record deletion taskset duration", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
    """Count one deletion that was blocked, labelled by *blocker*.

    Best-effort: any error raised while recording is logged at debug level
    and swallowed.
    """
    try:
        blocked = DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker)
        blocked.inc()
    except Exception:
        logger.debug("Failed to record deletion blocked", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_fence_reset(tenant_id: str) -> None:
    """Count one deletion fence reset for *tenant_id*.

    Best-effort: any error raised while recording is logged at debug level
    and swallowed.
    """
    try:
        fence_resets = DELETION_FENCE_RESET.labels(tenant_id=tenant_id)
        fence_resets.inc()
    except Exception:
        logger.debug("Failed to record deletion fence reset", exc_info=True)
|
||||
@@ -1,25 +1,30 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
"""Prometheus collectors for Celery queue depths and infrastructure health.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
These collectors query Redis at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
monitoring celery worker which already has Redis access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
To avoid hammering Redis on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based (emitted by
|
||||
workers at state-change time) and live in connector_health_metrics.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -31,6 +36,11 @@ logger = setup_logger()
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
# Maximum time (seconds) a single _collect_fresh() call may take before
|
||||
# the collector gives up and returns stale/empty results. Prevents the
|
||||
# /metrics endpoint from hanging indefinitely when a DB or Redis query stalls.
|
||||
_DEFAULT_COLLECT_TIMEOUT = 120.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
@@ -62,18 +72,32 @@ _UNACKED_QUEUES: list[str] = [
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
"""Base collector with TTL-based caching and timeout protection.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
|
||||
A per-collection timeout prevents a slow DB or Redis query from blocking
|
||||
the /metrics endpoint indefinitely. If _collect_fresh() exceeds the
|
||||
timeout, stale cached results are returned instead.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl: float = _DEFAULT_CACHE_TTL,
|
||||
collect_timeout: float = _DEFAULT_COLLECT_TIMEOUT,
|
||||
) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._collect_timeout = collect_timeout
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix=type(self).__name__,
|
||||
)
|
||||
self._inflight: concurrent.futures.Future | None = None
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
@@ -84,12 +108,28 @@ class _CachedCollector(Collector):
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
# If a previous _collect_fresh() is still running, wait on it
|
||||
# rather than queuing another. This prevents unbounded task
|
||||
# accumulation in the executor during extended DB outages.
|
||||
if self._inflight is not None and not self._inflight.done():
|
||||
future = self._inflight
|
||||
else:
|
||||
future = self._executor.submit(self._collect_fresh)
|
||||
self._inflight = future
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
result = future.result(timeout=self._collect_timeout)
|
||||
self._inflight = None
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.warning(
|
||||
f"{type(self).__name__}._collect_fresh() timed out after {self._collect_timeout}s, returning stale cache"
|
||||
)
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
except Exception:
|
||||
self._inflight = None
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
@@ -117,8 +157,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
@@ -194,208 +232,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
    """Collector that builds connector health gauges from database queries.

    Nothing is emitted until :meth:`configure` has been called. For each
    tenant the collector reports, per connector: seconds since the last
    successful index, whether it is in a repeated error state, documents
    indexed, and failed attempt count. It also reports per-tenant roll-ups:
    connectors grouped by status and the number of connectors in error.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Flipped by configure(); until then _collect_fresh() returns nothing.
        self._configured: bool = False

    def configure(self) -> None:
        """Mark the collector ready; call once the DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Query every tenant and return the populated gauge families.

        Returns an empty list when the collector has not been configured.
        """
        if not self._configured:
            return []

        # Imported lazily inside the method, as in the original code.
        from onyx.db.connector_credential_pair import (
            get_connector_health_for_metrics,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
        from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        # Label names shared by every per-connector gauge.
        per_connector_labels = [
            "tenant_id",
            "source",
            "cc_pair_id",
            "connector_name",
        ]

        staleness_gauge = GaugeMetricFamily(
            "onyx_connector_last_success_age_seconds",
            "Seconds since last successful index for this connector",
            labels=per_connector_labels,
        )
        error_state_gauge = GaugeMetricFamily(
            "onyx_connector_in_error_state",
            "Whether the connector is in a repeated error state (1=yes, 0=no)",
            labels=per_connector_labels,
        )
        docs_success_gauge = GaugeMetricFamily(
            "onyx_connector_docs_indexed",
            "Total new documents indexed (90-day rolling sum) per connector",
            labels=per_connector_labels,
        )
        docs_error_gauge = GaugeMetricFamily(
            "onyx_connector_error_count",
            "Total number of failed index attempts per connector",
            labels=per_connector_labels,
        )
        by_status_gauge = GaugeMetricFamily(
            "onyx_connectors_by_status",
            "Number of connectors grouped by status",
            labels=["tenant_id", "status"],
        )
        error_total_gauge = GaugeMetricFamily(
            "onyx_connectors_in_error_total",
            "Total number of connectors in repeated error state",
            labels=["tenant_id"],
        )

        # Single reference instant so every age in this pass is comparable.
        now = datetime.now(tz=timezone.utc)

        for tid in get_all_tenant_ids():
            # Defensive guard: get_all_tenant_ids() should never yield None,
            # but skip it anyway in case that contract changes.
            if tid is None:
                continue

            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    pairs = get_connector_health_for_metrics(session)
                    error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
                    docs_by_cc = get_docs_indexed_by_cc_pair(session)

                    status_counts: dict[str, int] = {}
                    error_count = 0

                    for row in pairs:
                        cc_id, status, in_error, last_success, cc_name, source = row
                        # Fall back to a synthetic name when none is set.
                        label_vals = [
                            tid,
                            source.value,
                            str(cc_id),
                            cc_name or f"cc_pair_{cc_id}",
                        ]

                        if last_success is not None:
                            # Both timestamps are timezone-aware (the DB
                            # column uses DateTime(timezone=True)), so the
                            # subtraction is safe.
                            staleness_gauge.add_metric(
                                label_vals,
                                (now - last_success).total_seconds(),
                            )

                        if in_error:
                            error_state_gauge.add_metric(label_vals, 1.0)
                            error_count += 1
                        else:
                            error_state_gauge.add_metric(label_vals, 0.0)

                        # Connectors absent from either mapping report zero.
                        docs_success_gauge.add_metric(
                            label_vals, docs_by_cc.get(cc_id, 0)
                        )
                        docs_error_gauge.add_metric(
                            label_vals, error_counts_by_cc.get(cc_id, 0)
                        )

                        status_key = status.value
                        status_counts[status_key] = (
                            status_counts.get(status_key, 0) + 1
                        )

                    for status_key, count in status_counts.items():
                        by_status_gauge.add_metric([tid, status_key], count)

                    error_total_gauge.add_metric([tid], error_count)
            finally:
                # Always restore the previous tenant context, even on failure.
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        return [
            staleness_gauge,
            error_state_gauge,
            by_status_gauge,
            error_total_gauge,
            docs_success_gauge,
            docs_error_gauge,
        ]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
@@ -411,8 +247,6 @@ class RedisHealthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
@@ -495,7 +329,9 @@ class WorkerHeartbeatMonitor:
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
limit=None,
|
||||
timeout=self._HEARTBEAT_TIMEOUT_SECONDS,
|
||||
wakeup=True,
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
|
||||
@@ -6,8 +6,6 @@ Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
@@ -21,8 +19,6 @@ logger = setup_logger()
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
@@ -34,6 +30,9 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based
|
||||
(see connector_health_metrics.py) and do not use collectors.
|
||||
"""
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
@@ -47,13 +46,8 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
|
||||
@@ -27,6 +27,8 @@ _DEFAULT_PORTS: dict[str, int] = {
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
"heavy": 9094,
|
||||
"light": 9095,
|
||||
"primary": 9097,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
|
||||
@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
|
||||
"onyx_pruning_enumeration_duration_seconds",
|
||||
"Duration of document ID enumeration from the source connector during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
|
||||
)
|
||||
|
||||
PRUNING_DIFF_DURATION = Histogram(
|
||||
"onyx_pruning_diff_duration_seconds",
|
||||
"Duration of diff computation and subtask dispatch during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
|
||||
)
|
||||
|
||||
PRUNING_RATE_LIMIT_ERRORS = Counter(
|
||||
|
||||
@@ -65,6 +65,7 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
@@ -89,7 +90,8 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
default=None,
|
||||
ge=0, # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
|
||||
@@ -208,12 +208,6 @@ class PythonToolOverrideKwargs(BaseModel):
|
||||
chat_files: list[ChatFile] = []
|
||||
|
||||
|
||||
class ImageGenerationToolOverrideKwargs(BaseModel):
    """Override kwargs for image generation tool calls."""

    # File IDs of images produced by earlier image-generation calls.
    # NOTE(review): presumably ordered oldest-first, so the last entry is the
    # most recent image; confirm against the code that populates this list.
    recent_generated_image_file_ids: list[str] = []
|
||||
|
||||
|
||||
class SearchToolRunContext(BaseModel):
|
||||
emitter: Emitter
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -131,6 +132,33 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -48,7 +47,7 @@ PROMPT_FIELD = "prompt"
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
NAME = "generate_image"
|
||||
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
|
||||
DISPLAY_NAME = "Image Generation"
|
||||
@@ -142,8 +141,11 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD: {
|
||||
"type": "array",
|
||||
"description": (
|
||||
"Optional image file IDs to use as reference context for edits/variations. "
|
||||
"Use the file_id values returned by previous generate_image calls."
|
||||
"Optional file_ids of existing images to edit or use as reference;"
|
||||
" the first is the primary edit source."
|
||||
" Get file_ids from `[attached image — file_id: <id>]` tags on"
|
||||
" user-attached images or from prior generate_image tool responses."
|
||||
" Omit for a fresh, unrelated generation."
|
||||
),
|
||||
"items": {
|
||||
"type": "string",
|
||||
@@ -254,41 +256,31 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
def _resolve_reference_image_file_ids(
|
||||
self,
|
||||
llm_kwargs: dict[str, Any],
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None,
|
||||
) -> list[str]:
|
||||
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
|
||||
if raw_reference_ids is not None:
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
reference_image_file_ids = [
|
||||
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
|
||||
]
|
||||
elif (
|
||||
override_kwargs
|
||||
and override_kwargs.recent_generated_image_file_ids
|
||||
and self.img_provider.supports_reference_images
|
||||
):
|
||||
# If no explicit reference was provided, default to the most recently generated image.
|
||||
reference_image_file_ids = [
|
||||
override_kwargs.recent_generated_image_file_ids[-1]
|
||||
]
|
||||
else:
|
||||
reference_image_file_ids = []
|
||||
if raw_reference_ids is None:
|
||||
# No references requested — plain generation.
|
||||
return []
|
||||
|
||||
# Deduplicate while preserving order.
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
|
||||
# Deduplicate while preserving order (first occurrence wins, so the
|
||||
# LLM's intended "primary edit source" stays at index 0).
|
||||
deduped_reference_image_ids: list[str] = []
|
||||
seen_ids: set[str] = set()
|
||||
for file_id in reference_image_file_ids:
|
||||
if file_id in seen_ids:
|
||||
for file_id in raw_reference_ids:
|
||||
file_id = file_id.strip()
|
||||
if not file_id or file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
deduped_reference_image_ids.append(file_id)
|
||||
@@ -302,14 +294,14 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
f"Reference images requested but provider '{self.provider}' does not support image-editing context."
|
||||
),
|
||||
llm_facing_message=(
|
||||
"This image provider does not support editing from previous image context. "
|
||||
"This image provider does not support editing from existing images. "
|
||||
"Try text-only generation, or switch to a provider/model that supports image edits."
|
||||
),
|
||||
)
|
||||
|
||||
max_reference_images = self.img_provider.max_reference_images
|
||||
if max_reference_images > 0:
|
||||
return deduped_reference_image_ids[-max_reference_images:]
|
||||
return deduped_reference_image_ids[:max_reference_images]
|
||||
return deduped_reference_image_ids
|
||||
|
||||
def _load_reference_images(
|
||||
@@ -358,7 +350,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
if PROMPT_FIELD not in llm_kwargs:
|
||||
@@ -373,7 +365,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
|
||||
reference_image_file_ids = self._resolve_reference_image_file_ids(
|
||||
llm_kwargs=llm_kwargs,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
reference_images = self._load_reference_images(reference_image_file_ids)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ParallelToolCallResponse
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
return merged_calls
|
||||
|
||||
|
||||
def _extract_image_file_ids_from_tool_response_message(
|
||||
message: str,
|
||||
) -> list[str]:
|
||||
try:
|
||||
parsed_message = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
parsed_items: list[Any] = (
|
||||
parsed_message if isinstance(parsed_message, list) else [parsed_message]
|
||||
)
|
||||
file_ids: list[str] = []
|
||||
for item in parsed_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
file_id = item.get("file_id")
|
||||
if isinstance(file_id, str):
|
||||
file_ids.append(file_id)
|
||||
|
||||
return file_ids
|
||||
|
||||
|
||||
def _extract_recent_generated_image_file_ids(
    message_history: list[ChatMessageSimple],
) -> list[str]:
    """Collect file IDs of images generated earlier in the conversation.

    Walks the history in order. Assistant messages register which tool each
    tool-call ID belongs to; tool-call responses whose ID maps to the image
    generation tool are then parsed for file IDs. Each ID is returned once,
    at the position where it first appeared.
    """
    tool_by_call_id: dict[str, str] = {}
    # A dict keeps insertion order, so it serves as an ordered set here.
    ordered_file_ids: dict[str, None] = {}

    for msg in message_history:
        if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
            tool_by_call_id.update(
                (call.tool_call_id, call.tool_name) for call in msg.tool_calls
            )
            continue

        is_tool_response = msg.message_type == MessageType.TOOL_CALL_RESPONSE
        if not (is_tool_response and msg.tool_call_id):
            continue

        # Skip responses from other tools and from unregistered call IDs.
        if tool_by_call_id.get(msg.tool_call_id) != ImageGenerationTool.NAME:
            continue

        for file_id in _extract_image_file_ids_from_tool_response_message(
            msg.message
        ):
            ordered_file_ids.setdefault(file_id, None)

    return list(ordered_file_ids)
|
||||
|
||||
|
||||
def _safe_run_single_tool(
|
||||
tool: Tool,
|
||||
tool_call: ToolCallKickoff,
|
||||
@@ -386,9 +324,6 @@ def run_tool_calls(
|
||||
url_to_citation: dict[str, int] = {
|
||||
url: citation_num for citation_num, url in citation_mapping.items()
|
||||
}
|
||||
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
|
||||
message_history
|
||||
)
|
||||
|
||||
# Prepare all tool calls with their override_kwargs
|
||||
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
|
||||
@@ -405,7 +340,6 @@ def run_tool_calls(
|
||||
| WebSearchToolOverrideKwargs
|
||||
| OpenURLToolOverrideKwargs
|
||||
| PythonToolOverrideKwargs
|
||||
| ImageGenerationToolOverrideKwargs
|
||||
| MemoryToolOverrideKwargs
|
||||
| None
|
||||
) = None
|
||||
@@ -454,10 +388,6 @@ def run_tool_calls(
|
||||
override_kwargs = PythonToolOverrideKwargs(
|
||||
chat_files=chat_files or [],
|
||||
)
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
override_kwargs = ImageGenerationToolOverrideKwargs(
|
||||
recent_generated_image_file_ids=recent_generated_image_file_ids
|
||||
)
|
||||
elif isinstance(tool, MemoryTool):
|
||||
override_kwargs = MemoryToolOverrideKwargs(
|
||||
user_name=(
|
||||
|
||||
@@ -38,38 +38,41 @@ class TestAddMemory:
|
||||
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that add_memory inserts a new Memory row."""
|
||||
user_id = test_user.id
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="User prefers dark mode",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert memory.id is not None
|
||||
assert memory.user_id == user_id
|
||||
assert memory.memory_text == "User prefers dark mode"
|
||||
assert memory_id is not None
|
||||
|
||||
# Verify it persists
|
||||
fetched = db_session.get(Memory, memory.id)
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == user_id
|
||||
assert fetched.memory_text == "User prefers dark mode"
|
||||
|
||||
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that multiple memories can be added for the same user."""
|
||||
user_id = test_user.id
|
||||
m1 = add_memory(
|
||||
m1_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Favorite color is blue",
|
||||
db_session=db_session,
|
||||
)
|
||||
m2 = add_memory(
|
||||
m2_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Works in engineering",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert m1.id != m2.id
|
||||
assert m1.memory_text == "Favorite color is blue"
|
||||
assert m2.memory_text == "Works in engineering"
|
||||
assert m1_id != m2_id
|
||||
fetched_m1 = db_session.get(Memory, m1_id)
|
||||
fetched_m2 = db_session.get(Memory, m2_id)
|
||||
assert fetched_m1 is not None
|
||||
assert fetched_m2 is not None
|
||||
assert fetched_m1.memory_text == "Favorite color is blue"
|
||||
assert fetched_m2.memory_text == "Works in engineering"
|
||||
|
||||
|
||||
class TestUpdateMemoryAtIndex:
|
||||
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
|
||||
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
|
||||
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
|
||||
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=1,
|
||||
new_text="Updated Memory 1",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated Memory 1"
|
||||
assert updated_id is not None
|
||||
fetched = db_session.get(Memory, updated_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Updated Memory 1"
|
||||
|
||||
def test_update_memory_at_out_of_range_index(
|
||||
self, db_session: Session, test_user: User
|
||||
@@ -167,7 +172,7 @@ class TestMemoryCap:
|
||||
assert len(rows_before) == MAX_MEMORIES_PER_USER
|
||||
|
||||
# Add one more — should evict the oldest
|
||||
new_memory = add_memory(
|
||||
new_memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="New memory after cap",
|
||||
db_session=db_session,
|
||||
@@ -181,7 +186,7 @@ class TestMemoryCap:
|
||||
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
|
||||
assert rows_after[0].memory_text == "Memory 1"
|
||||
# Newest should be the one we just added
|
||||
assert rows_after[-1].id == new_memory.id
|
||||
assert rows_after[-1].id == new_memory_id
|
||||
assert rows_after[-1].memory_text == "New memory after cap"
|
||||
|
||||
|
||||
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
|
||||
user_id = test_user_no_memories.id
|
||||
|
||||
# Add a memory
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert memory.memory_text == "Memory with use_memories off"
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Memory with use_memories off"
|
||||
|
||||
# Update that memory
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=0,
|
||||
new_text="Updated memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated memory with use_memories off"
|
||||
assert updated_id is not None
|
||||
fetched_updated = db_session.get(Memory, updated_id)
|
||||
assert fetched_updated is not None
|
||||
assert fetched_updated.memory_text == "Updated memory with use_memories off"
|
||||
|
||||
# Verify get_memories returns the updated memory
|
||||
context = get_memories(test_user_no_memories, db_session)
|
||||
|
||||
@@ -9,6 +9,7 @@ from unittest.mock import patch
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
@@ -214,3 +215,43 @@ class TestCheckSeatAvailabilityMultiTenant:
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
|
||||
class TestGetUsedSeatsAccountTypeFiltering:
    """Check which account types the get_used_seats query filters on.

    The session is mocked, so these tests inspect the SQL text handed to
    ``session.execute`` and never touch a real database.
    """

    @staticmethod
    def _install_session(mock_get_session: MagicMock, seat_count: int) -> MagicMock:
        """Make the patched session factory yield a mock reporting seat_count."""
        session = MagicMock()
        context_manager = mock_get_session.return_value
        context_manager.__enter__ = MagicMock(return_value=session)
        context_manager.__exit__ = MagicMock(return_value=False)
        session.execute.return_value.scalar.return_value = seat_count
        return session

    @staticmethod
    def _compiled_sql(session: MagicMock) -> str:
        """Render the first positional argument of execute() with literals inlined."""
        query = session.execute.call_args[0][0]
        return str(query.compile(compile_kwargs={"literal_binds": True}))

    @patch("ee.onyx.db.license.MULTI_TENANT", False)
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
        """SERVICE_ACCOUNT users should not count toward seats."""
        session = self._install_session(mock_get_session, 5)

        result = get_used_seats()

        assert result == 5
        # The account_type filter must mention SERVICE_ACCOUNT...
        compiled = self._compiled_sql(session)
        assert "SERVICE_ACCOUNT" in compiled
        # ...while BOT must not appear, i.e. it is not excluded.
        assert "BOT" not in compiled

    @patch("ee.onyx.db.license.MULTI_TENANT", False)
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
        """EXT_PERM_USER exclusion should still be present."""
        session = self._install_session(mock_get_session, 3)

        get_used_seats()

        assert "EXT_PERM_USER" in self._compiled_sql(session)
|
||||
|
||||
@@ -301,7 +301,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -332,7 +331,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -363,7 +361,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -391,7 +388,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -423,7 +419,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -456,7 +451,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -497,7 +491,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -519,7 +512,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -542,7 +534,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -596,7 +587,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -653,7 +643,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -706,7 +695,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -736,7 +724,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
|
||||
|
||||
def test_time_str_to_utc() -> None:
|
||||
str_to_dt = {
|
||||
"Tue, 5 Oct 2021 09:38:25 GMT": datetime.datetime(
|
||||
2021, 10, 5, 9, 38, 25, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Sat, 24 Jul 2021 09:21:20 +0000 (UTC)": datetime.datetime(
|
||||
2021, 7, 24, 9, 21, 20, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Thu, 29 Jul 2021 04:20:37 -0400 (EDT)": datetime.datetime(
|
||||
2021, 7, 29, 8, 20, 37, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"30 Jun 2023 18:45:01 +0300": datetime.datetime(
|
||||
2023, 6, 30, 15, 45, 1, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"22 Mar 2020 20:12:18 +0000 (GMT)": datetime.datetime(
|
||||
2020, 3, 22, 20, 12, 18, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Date: Wed, 27 Aug 2025 11:40:00 +0200": datetime.datetime(
|
||||
2025, 8, 27, 9, 40, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
for strptime, expected_datetime in str_to_dt.items():
|
||||
assert time_str_to_utc(strptime) == expected_datetime
|
||||
|
||||
|
||||
def test_time_str_to_utc_recovers_from_concatenated_headers() -> None:
|
||||
# TZ is dropped during recovery, so the expected result is UTC rather
|
||||
# than the original offset.
|
||||
assert time_str_to_utc(
|
||||
'Sat, 3 Nov 2007 14:33:28 -0200To: "jason" <jason@example.net>'
|
||||
) == datetime.datetime(2007, 11, 3, 14, 33, 28, tzinfo=datetime.timezone.utc)
|
||||
|
||||
assert time_str_to_utc(
|
||||
"Fri, 20 Feb 2015 10:30:00 +0500Cc: someone@example.com"
|
||||
) == datetime.datetime(2015, 2, 20, 10, 30, 0, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def test_time_str_to_utc_raises_on_impossible_dates() -> None:
|
||||
for bad in (
|
||||
"Wed, 33 Sep 2007 13:42:59 +0100",
|
||||
"Thu, 11 Oct 2007 31:50:55 +0900",
|
||||
"not a date at all",
|
||||
"",
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
time_str_to_utc(bad)
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
@@ -8,7 +9,6 @@ from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.gmail.connector import _build_time_range_query
|
||||
from onyx.connectors.gmail.connector import GmailCheckpoint
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
@@ -51,29 +51,43 @@ def test_build_time_range_query() -> None:
|
||||
assert query is None
|
||||
|
||||
|
||||
def test_time_str_to_utc() -> None:
|
||||
str_to_dt = {
|
||||
"Tue, 5 Oct 2021 09:38:25 GMT": datetime.datetime(
|
||||
2021, 10, 5, 9, 38, 25, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Sat, 24 Jul 2021 09:21:20 +0000 (UTC)": datetime.datetime(
|
||||
2021, 7, 24, 9, 21, 20, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Thu, 29 Jul 2021 04:20:37 -0400 (EDT)": datetime.datetime(
|
||||
2021, 7, 29, 8, 20, 37, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"30 Jun 2023 18:45:01 +0300": datetime.datetime(
|
||||
2023, 6, 30, 15, 45, 1, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"22 Mar 2020 20:12:18 +0000 (GMT)": datetime.datetime(
|
||||
2020, 3, 22, 20, 12, 18, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Date: Wed, 27 Aug 2025 11:40:00 +0200": datetime.datetime(
|
||||
2025, 8, 27, 9, 40, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
for strptime, expected_datetime in str_to_dt.items():
|
||||
assert time_str_to_utc(strptime) == expected_datetime
|
||||
def _thread_with_date(date_header: str | None) -> dict[str, Any]:
|
||||
"""Load the fixture thread and replace (or strip, if None) its Date header."""
|
||||
json_path = os.path.join(os.path.dirname(__file__), "thread.json")
|
||||
with open(json_path, "r") as f:
|
||||
thread = cast(dict[str, Any], json.load(f))
|
||||
thread = copy.deepcopy(thread)
|
||||
|
||||
for message in thread["messages"]:
|
||||
headers: list[dict[str, str]] = message["payload"]["headers"]
|
||||
if date_header is None:
|
||||
message["payload"]["headers"] = [
|
||||
h for h in headers if h.get("name") != "Date"
|
||||
]
|
||||
continue
|
||||
|
||||
replaced = False
|
||||
for header in headers:
|
||||
if header.get("name") == "Date":
|
||||
header["value"] = date_header
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
headers.append({"name": "Date", "value": date_header})
|
||||
|
||||
return thread
|
||||
|
||||
|
||||
def test_thread_to_document_skips_unparseable_dates() -> None:
|
||||
for bad_date in (
|
||||
"Wed, 33 Sep 2007 13:42:59 +0100",
|
||||
"Thu, 11 Oct 2007 31:50:55 +0900",
|
||||
"total garbage not even close to a date",
|
||||
):
|
||||
doc = thread_to_document(_thread_with_date(bad_date), "admin@example.com")
|
||||
assert isinstance(doc, Document), f"failed for {bad_date!r}"
|
||||
assert doc.doc_updated_at is None
|
||||
assert doc.id == "192edefb315737c3"
|
||||
|
||||
|
||||
def test_gmail_checkpoint_progression() -> None:
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.google_kv import get_auth_url
|
||||
from onyx.connectors.google_utils.google_kv import get_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import get_service_account_key
|
||||
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleAppWebCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
|
||||
|
||||
def _make_app_creds() -> GoogleAppCredentials:
|
||||
return GoogleAppCredentials(
|
||||
web=GoogleAppWebCredentials(
|
||||
client_id="client-id.apps.googleusercontent.com",
|
||||
project_id="test-project",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_secret="secret",
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
javascript_origins=["https://example.com"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_service_account_key() -> GoogleServiceAccountKey:
|
||||
return GoogleServiceAccountKey(
|
||||
type="service_account",
|
||||
project_id="test-project",
|
||||
private_key_id="private-key-id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
client_email="test@test-project.iam.gserviceaccount.com",
|
||||
client_id="123",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["project_id"] == "test-project"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_google_app_cred_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_service_account_key_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
stored_value: object = (
|
||||
_make_service_account_key().model_dump(mode="json")
|
||||
if not legacy_string
|
||||
else _make_service_account_key().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_auth_url_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
stored_state: dict[str, object] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored_state["key"] = key
|
||||
stored_state["value"] = value
|
||||
stored_state["encrypt"] = encrypt
|
||||
|
||||
class _StubFlow:
|
||||
def authorization_url(self, prompt: str) -> tuple[str, None]:
|
||||
assert prompt == "consent"
|
||||
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
def _from_client_config(
|
||||
_app_config: object, *, scopes: object, redirect_uri: object
|
||||
) -> _StubFlow:
|
||||
del scopes, redirect_uri
|
||||
return _StubFlow()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
|
||||
_from_client_config,
|
||||
)
|
||||
|
||||
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert auth_url.startswith("https://accounts.google.com")
|
||||
assert stored_state["value"] == {"value": "test-state"}
|
||||
assert stored_state["encrypt"] is True
|
||||
@@ -6,6 +6,7 @@ import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
@@ -145,3 +146,29 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
|
||||
|
||||
def test_bulk_fetch_respects_api_batch_limit() -> None:
|
||||
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
|
||||
client = _mock_jira_client()
|
||||
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
|
||||
all_ids = [str(i) for i in range(total_issues)]
|
||||
|
||||
batch_sizes: list[int] = []
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
batch_sizes.append(len(ids))
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, all_ids)
|
||||
|
||||
assert len(result) == total_issues
|
||||
# keeping this hardcoded because it's the documented limit
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
assert all(size <= 100 for size in batch_sizes)
|
||||
assert len(batch_sizes) == 4
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -12,6 +12,10 @@ dependency on pypdf internals (pypdf.generic).
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import pdf_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.password_validation import is_pdf_protected
|
||||
@@ -96,6 +100,80 @@ class TestReadPdfFile:
|
||||
# Returned list is empty when callback is used
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_skips_images_above_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the embedded-image cap is exceeded, remaining images are skipped.
|
||||
|
||||
The cap protects the user-file-processing worker from OOMing on PDFs
|
||||
with thousands of embedded images. Setting the cap to 0 should yield
|
||||
zero extracted images even though the fixture has one.
|
||||
"""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_at_limit_extracts_up_to_cap(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A cap >= image count behaves identically to the uncapped path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert len(images) == 1
|
||||
|
||||
def test_image_cap_with_callback_stops_streaming_at_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""The cap also short-circuits the streaming callback path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
collected: list[tuple[bytes, str]] = []
|
||||
|
||||
def callback(data: bytes, name: str) -> None:
|
||||
collected.append((data, name))
|
||||
|
||||
read_pdf_file(
|
||||
_load("with_image.pdf"), extract_images=True, image_callback=callback
|
||||
)
|
||||
assert collected == []
|
||||
|
||||
|
||||
# ── count_pdf_embedded_images ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountPdfEmbeddedImages:
|
||||
def test_returns_count_for_normal_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
|
||||
|
||||
def test_short_circuits_above_cap(self) -> None:
|
||||
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
|
||||
# function returns on first increment as the over-cap sentinel.
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
|
||||
|
||||
def test_returns_zero_for_pdf_without_images(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_invalid_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_password_locked_pdf(self) -> None:
|
||||
# encrypted.pdf has an open password; we can't inspect without it, so
|
||||
# the helper returns 0 — callers rely on the password-protected check
|
||||
# that runs earlier in the upload pipeline.
|
||||
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
|
||||
|
||||
def test_inspects_owner_password_only_pdf(self) -> None:
|
||||
# owner_protected.pdf is encrypted but has no open password. It should
|
||||
# decrypt with an empty string and count images normally. The fixture
|
||||
# has zero images, so 0 is a real count (not the "bail on encrypted"
|
||||
# path).
|
||||
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("with_image.pdf")
|
||||
pdf.seek(42)
|
||||
count_pdf_embedded_images(pdf, cap=10)
|
||||
assert pdf.tell() == 42
|
||||
|
||||
|
||||
# ── pdf_to_text ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.llm.utils import get_max_input_tokens
|
||||
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
|
||||
"claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -505,6 +505,7 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
|
||||
@@ -100,6 +100,39 @@ class TestGenerateOllamaDisplayName:
|
||||
result = generate_ollama_display_name("llama3.3:70b")
|
||||
assert "3.3" in result or "3 3" in result # Either format is acceptable
|
||||
|
||||
def test_non_size_tag_shown(self) -> None:
|
||||
"""Test that non-size tags like 'e4b' are included in the display name."""
|
||||
result = generate_ollama_display_name("gemma4:e4b")
|
||||
assert "Gemma" in result
|
||||
assert "4" in result
|
||||
assert "E4B" in result
|
||||
|
||||
def test_size_with_cloud_modifier(self) -> None:
|
||||
"""Test size tag with cloud modifier."""
|
||||
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
|
||||
assert "DeepSeek" in result
|
||||
assert "671B" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_size_with_multiple_modifiers(self) -> None:
|
||||
"""Test size tag with multiple modifiers."""
|
||||
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
|
||||
assert "Qwen" in result
|
||||
assert "235B" in result
|
||||
assert "Instruct" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_quantization_tag_shown(self) -> None:
|
||||
"""Test that quantization tags are included in the display name."""
|
||||
result = generate_ollama_display_name("llama3:q4_0")
|
||||
assert "Llama" in result
|
||||
assert "Q4_0" in result
|
||||
|
||||
def test_cloud_only_tag(self) -> None:
|
||||
"""Test standalone cloud tag."""
|
||||
result = generate_ollama_display_name("glm-4.6:cloud")
|
||||
assert "CLOUD" in result
|
||||
|
||||
|
||||
class TestStripOpenrouterVendorPrefix:
|
||||
"""Tests for OpenRouter vendor prefix stripping."""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -9,7 +10,9 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import _construct_tools_impl
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
source = inspect.getsource(_construct_tools_impl)
|
||||
assert (
|
||||
"DISABLE_VECTOR_DB" in source
|
||||
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Tests for ``ImageGenerationTool._resolve_reference_image_file_ids``.
|
||||
|
||||
The resolver turns the LLM's ``reference_image_file_ids`` argument into a
|
||||
cleaned list of file IDs to hand to ``_load_reference_images``. It trusts
|
||||
the LLM's picks — the LLM can only see file IDs that actually appear in
|
||||
the conversation (via ``[attached image — file_id: <id>]`` tags on user
|
||||
messages and the JSON returned by prior generate_image calls), so we
|
||||
don't re-validate against an allow-list in the tool itself.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD,
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
supports_reference_images: bool = True,
|
||||
max_reference_images: int = 16,
|
||||
) -> ImageGenerationTool:
|
||||
"""Construct a tool with a mock provider so no credentials/network are needed."""
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider"
|
||||
) as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_reference_images = supports_reference_images
|
||||
mock_provider.max_reference_images = max_reference_images
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
return ImageGenerationTool(
|
||||
image_generation_credentials=MagicMock(),
|
||||
tool_id=1,
|
||||
emitter=MagicMock(),
|
||||
model="gpt-image-1",
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveReferenceImageFileIds:
|
||||
def test_unset_returns_empty_plain_generation(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool._resolve_reference_image_file_ids(llm_kwargs={}) == []
|
||||
|
||||
def test_empty_list_is_treated_like_unset(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: []},
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_passes_llm_supplied_ids_through(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["upload-1", "gen-1"]},
|
||||
)
|
||||
# Order preserved — first entry is the primary edit source.
|
||||
assert result == ["upload-1", "gen-1"]
|
||||
|
||||
def test_invalid_shape_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: "not-a-list"},
|
||||
)
|
||||
|
||||
def test_non_string_element_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["ok", 123]},
|
||||
)
|
||||
|
||||
def test_deduplicates_preserving_first_occurrence(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1", "gen-2", "gen-1"]},
|
||||
)
|
||||
assert result == ["gen-1", "gen-2"]
|
||||
|
||||
def test_strips_whitespace_and_skips_empty_strings(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: [" gen-1 ", "", " "]},
|
||||
)
|
||||
assert result == ["gen-1"]
|
||||
|
||||
def test_provider_without_reference_support_raises(self) -> None:
|
||||
tool = _make_tool(supports_reference_images=False)
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1"]},
|
||||
)
|
||||
|
||||
def test_truncates_to_provider_max_preserving_head(self) -> None:
|
||||
"""When the LLM lists more images than the provider allows, keep the
|
||||
HEAD of the list (the primary edit source + earliest extras) rather
|
||||
than the tail, since the LLM put the most important one first."""
|
||||
tool = _make_tool(max_reference_images=2)
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["a", "b", "c", "d"]},
|
||||
)
|
||||
assert result == ["a", "b"]
|
||||
@@ -1,10 +1,5 @@
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
|
||||
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
|
||||
from onyx.tools.tool_runner import _merge_tool_calls
|
||||
|
||||
|
||||
@@ -312,62 +307,3 @@ class TestMergeToolCalls:
|
||||
assert len(result) == 1
|
||||
# String should be converted to list item
|
||||
assert result[0].tool_args["queries"] == ["single_query", "q2"]
|
||||
|
||||
|
||||
class TestImageHistoryExtraction:
|
||||
def test_extracts_image_file_ids_from_json_response(self) -> None:
|
||||
msg = '[{"file_id":"img-1","revised_prompt":"v1"},{"file_id":"img-2","revised_prompt":"v2"}]'
|
||||
assert _extract_image_file_ids_from_tool_response_message(msg) == [
|
||||
"img-1",
|
||||
"img-2",
|
||||
]
|
||||
|
||||
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="generate_image",
|
||||
tool_arguments={"prompt": "test"},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
|
||||
|
||||
def test_ignores_non_image_tool_responses(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_arguments={"queries": ["q"]},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == []
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
"""Tests for generic Celery task lifecycle Prometheus metrics."""
|
||||
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery.apps.app_base import on_before_task_publish
|
||||
from onyx.server.metrics.celery_task_metrics import _task_start_times
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_QUEUE_WAIT
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
|
||||
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
|
||||
|
||||
@@ -22,11 +25,18 @@ def reset_metrics() -> Iterator[None]:
|
||||
_task_start_times.clear()
|
||||
|
||||
|
||||
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
|
||||
def _make_task(
|
||||
name: str = "test_task",
|
||||
queue: str = "test_queue",
|
||||
enqueued_at: float | None = None,
|
||||
) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
task.request = MagicMock()
|
||||
task.request.delivery_info = {"routing_key": queue}
|
||||
task.request.headers = (
|
||||
{"enqueued_at": enqueued_at} if enqueued_at is not None else {}
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@@ -72,6 +82,35 @@ class TestCeleryTaskPrerun:
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
def test_observes_queue_wait_when_enqueued_at_present(self) -> None:
|
||||
enqueued_at = time.time() - 30 # simulates 30s wait
|
||||
task = _make_task(enqueued_at=enqueued_at)
|
||||
|
||||
before = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
after = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
assert after >= before + 30
|
||||
|
||||
def test_skips_queue_wait_when_enqueued_at_missing(self) -> None:
|
||||
task = _make_task() # no enqueued_at in headers
|
||||
|
||||
before = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_prerun("task-2", task)
|
||||
|
||||
after = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
assert after == before
|
||||
|
||||
|
||||
class TestCeleryTaskPostrun:
|
||||
def test_increments_completed_success(self) -> None:
|
||||
@@ -151,3 +190,15 @@ class TestCeleryTaskPostrun:
|
||||
task = _make_task()
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
# Should not raise
|
||||
|
||||
|
||||
class TestBeforeTaskPublish:
|
||||
def test_stamps_enqueued_at_into_headers(self) -> None:
|
||||
before = time.time()
|
||||
headers: dict = {}
|
||||
on_before_task_publish(headers=headers)
|
||||
assert "enqueued_at" in headers
|
||||
assert headers["enqueued_at"] >= before
|
||||
|
||||
def test_noop_when_headers_is_none(self) -> None:
|
||||
on_before_task_publish(headers=None) # should not raise
|
||||
|
||||
204
backend/tests/unit/server/metrics/test_deletion_metrics.py
Normal file
204
backend/tests/unit/server/metrics/test_deletion_metrics.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for deletion-specific Prometheus metrics."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_BLOCKED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_COMPLETED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_FENCE_RESET
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_STARTED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_TASKSET_DURATION
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
|
||||
|
||||
class TestIncDeletionStarted:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
|
||||
inc_deletion_started("t1")
|
||||
|
||||
after = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
before_t2 = DELETION_STARTED.labels(tenant_id="t2")._value.get()
|
||||
|
||||
inc_deletion_started("t1")
|
||||
|
||||
assert DELETION_STARTED.labels(tenant_id="t1")._value.get() == before_t1 + 1
|
||||
assert DELETION_STARTED.labels(tenant_id="t2")._value.get() == before_t2
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_STARTED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_started("t1")
|
||||
|
||||
|
||||
class TestIncDeletionCompleted:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
after = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_outcome(self) -> None:
|
||||
before_success = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
before_failure = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
assert (
|
||||
DELETION_COMPLETED.labels(tenant_id="t1", outcome="success")._value.get()
|
||||
== before_success + 1
|
||||
)
|
||||
assert (
|
||||
DELETION_COMPLETED.labels(tenant_id="t1", outcome="failure")._value.get()
|
||||
== before_failure
|
||||
)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_COMPLETED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
|
||||
class TestObserveDeletionTasksetDuration:
|
||||
def test_observes_duration(self) -> None:
|
||||
before = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "success", 120.0)
|
||||
|
||||
after = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
assert after == pytest.approx(before + 120.0)
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
before_t2 = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t2", outcome="success"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "success", 60.0)
|
||||
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_t1 + 60.0)
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t2", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_t2)
|
||||
|
||||
def test_labels_by_outcome(self) -> None:
|
||||
before_success = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
before_failure = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "failure", 45.0)
|
||||
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_success)
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._sum.get() == pytest.approx(before_failure + 45.0)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_TASKSET_DURATION,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
observe_deletion_taskset_duration("t1", "success", 10.0)
|
||||
|
||||
|
||||
class TestIncDeletionBlocked:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="indexing"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
after = DELETION_BLOCKED.labels(tenant_id="t1", blocker="indexing")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_blocker(self) -> None:
|
||||
before_idx = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="indexing"
|
||||
)._value.get()
|
||||
before_prune = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="pruning"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
assert (
|
||||
DELETION_BLOCKED.labels(tenant_id="t1", blocker="indexing")._value.get()
|
||||
== before_idx + 1
|
||||
)
|
||||
assert (
|
||||
DELETION_BLOCKED.labels(tenant_id="t1", blocker="pruning")._value.get()
|
||||
== before_prune
|
||||
)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_BLOCKED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
|
||||
class TestIncDeletionFenceReset:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
|
||||
inc_deletion_fence_reset("t1")
|
||||
|
||||
after = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
before_t2 = DELETION_FENCE_RESET.labels(tenant_id="t2")._value.get()
|
||||
|
||||
inc_deletion_fence_reset("t1")
|
||||
|
||||
assert DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get() == before_t1 + 1
|
||||
assert DELETION_FENCE_RESET.labels(tenant_id="t2")._value.get() == before_t2
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_FENCE_RESET,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_fence_reset("t1")
|
||||
@@ -1,16 +1,11 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@@ -18,7 +13,7 @@ from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
@@ -137,212 +132,3 @@ class TestQueueDepthCollector:
|
||||
stale_result = collector.collect()
|
||||
|
||||
assert stale_result is good_result
|
||||
|
||||
|
||||
class TestIndexAttemptCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_index_attempts(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
mock_row = (
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
MagicMock(value="web"),
|
||||
81,
|
||||
"Table Tennis Blade Guide",
|
||||
2,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
|
||||
mock_row
|
||||
]
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 1
|
||||
assert families[0].name == "onyx_index_attempts_active"
|
||||
assert len(families[0].samples) == 1
|
||||
sample = families[0].samples[0]
|
||||
assert sample.labels == {
|
||||
"status": "in_progress",
|
||||
"source": "web",
|
||||
"tenant_id": "public",
|
||||
"connector_name": "Table Tennis Blade Guide",
|
||||
"cc_pair_id": "81",
|
||||
}
|
||||
assert sample.value == 2
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
# No stale cache, so returns empty
|
||||
assert families == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_skips_none_tenant_ids(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = [None]
|
||||
families = collector.collect()
|
||||
assert len(families) == 1 # Returns the gauge family, just with no samples
|
||||
assert len(families[0].samples) == 0
|
||||
|
||||
|
||||
class TestConnectorHealthCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_connector_health(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
last_success = now - timedelta(hours=2)
|
||||
|
||||
mock_status = MagicMock(value="ACTIVE")
|
||||
mock_source = MagicMock(value="google_drive")
|
||||
# Row: (id, status, in_error, last_success, name, source)
|
||||
mock_row = (
|
||||
42,
|
||||
mock_status,
|
||||
True, # in_repeated_error_state
|
||||
last_success,
|
||||
"My GDrive Connector",
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
# Mock the index attempt queries (error counts + docs counts)
|
||||
mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
|
||||
[]
|
||||
)
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 6
|
||||
names = {f.name for f in families}
|
||||
assert names == {
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"onyx_connector_in_error_state",
|
||||
"onyx_connectors_by_status",
|
||||
"onyx_connectors_in_error_total",
|
||||
"onyx_connector_docs_indexed",
|
||||
"onyx_connector_error_count",
|
||||
}
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 1
|
||||
assert staleness.samples[0].value == pytest.approx(7200, abs=5)
|
||||
|
||||
error_state = next(
|
||||
f for f in families if f.name == "onyx_connector_in_error_state"
|
||||
)
|
||||
assert error_state.samples[0].value == 1.0
|
||||
|
||||
by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
|
||||
assert by_status.samples[0].labels == {
|
||||
"tenant_id": "public",
|
||||
"status": "ACTIVE",
|
||||
}
|
||||
assert by_status.samples[0].value == 1
|
||||
|
||||
error_total = next(
|
||||
f for f in families if f.name == "onyx_connectors_in_error_total"
|
||||
)
|
||||
assert error_total.samples[0].value == 1
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_skips_staleness_when_no_last_success(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_status = MagicMock(value="INITIAL_INDEXING")
|
||||
mock_source = MagicMock(value="slack")
|
||||
mock_row = (
|
||||
10,
|
||||
mock_status,
|
||||
False,
|
||||
None, # no last_successful_index_time
|
||||
0,
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 0
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
assert families == []
|
||||
|
||||
@@ -217,11 +217,23 @@ Enriches docfetching and docprocessing tasks with connector-level labels. Silent
|
||||
| `onyx_indexing_task_completed_total` | Counter | `task_name`, `source`, `tenant_id`, `cc_pair_id`, `outcome` | Indexing tasks completed per connector |
|
||||
| `onyx_indexing_task_duration_seconds` | Histogram | `task_name`, `source`, `tenant_id` | Indexing task duration by connector type |
|
||||
|
||||
`connector_name` is intentionally excluded from these push-based counters to avoid unbounded cardinality (it's a free-form user string). The pull-based collectors on the monitoring worker include it since they have bounded cardinality (one series per connector).
|
||||
`connector_name` is intentionally excluded from these per-task counters to avoid unbounded cardinality (it's a free-form user string).
|
||||
|
||||
### Connector Health Metrics (`onyx.server.metrics.connector_health_metrics`)
|
||||
|
||||
Push-based metrics emitted by docfetching and docprocessing workers at the point where connector state changes occur. Scales to any number of tenants (no schema iteration). Unlike the per-task counters above, these include `connector_name` because their cardinality is bounded by the number of connectors (one series per connector), not by the number of task executions.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ----------------------------------------------- | ------- | --------------------------------------------------------------- | ------------------------------------------------------------- |
|
||||
| `onyx_index_attempt_transitions_total` | Counter | `tenant_id`, `source`, `cc_pair_id`, `connector_name`, `status` | Index attempt status transitions (in_progress, success, etc.) |
|
||||
| `onyx_connector_in_error_state` | Gauge | `tenant_id`, `source`, `cc_pair_id`, `connector_name` | Whether connector is in repeated error state (1=yes, 0=no) |
|
||||
| `onyx_connector_last_success_timestamp_seconds` | Gauge | `tenant_id`, `source`, `cc_pair_id`, `connector_name` | Unix timestamp of last successful indexing |
|
||||
| `onyx_connector_docs_indexed_total` | Counter | `tenant_id`, `source`, `cc_pair_id`, `connector_name` | Total documents indexed per connector (monotonic) |
|
||||
| `onyx_connector_indexing_errors_total` | Counter | `tenant_id`, `source`, `cc_pair_id`, `connector_name` | Total failed index attempts per connector (monotonic) |
|
||||
|
||||
### Pull-Based Collectors (`onyx.server.metrics.indexing_pipeline`)
|
||||
|
||||
Registered only in the **Monitoring** worker. Collectors query Redis/Postgres at scrape time with a 30-second TTL cache.
|
||||
Registered only in the **Monitoring** worker. Collectors query Redis at scrape time with a 30-second TTL cache and a 120-second timeout to prevent the `/metrics` endpoint from hanging.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------------ | ----- | ------- | ----------------------------------- |
|
||||
@@ -229,8 +241,6 @@ Registered only in the **Monitoring** worker. Collectors query Redis/Postgres at
|
||||
| `onyx_queue_unacked` | Gauge | `queue` | Unacknowledged messages per queue |
|
||||
| `onyx_queue_oldest_task_age_seconds` | Gauge | `queue` | Age of the oldest task in the queue |
|
||||
|
||||
Plus additional connector health, index attempt, and worker heartbeat metrics — see `indexing_pipeline.py` for the full list.
|
||||
|
||||
### Adding Metrics to a Worker
|
||||
|
||||
Currently only the docfetching and docprocessing workers have push-based task metrics wired up. To add metrics to another worker (e.g. heavy, light, primary):
|
||||
|
||||
@@ -15,6 +15,7 @@ type InteractiveStatefulVariant =
|
||||
| "select-heavy"
|
||||
| "select-card"
|
||||
| "select-tinted"
|
||||
| "select-input"
|
||||
| "select-filter"
|
||||
| "sidebar-heavy"
|
||||
| "sidebar-light";
|
||||
@@ -35,6 +36,7 @@ interface InteractiveStatefulProps
|
||||
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
|
||||
* - `"select-card"` — like select-heavy but filled state has a visible background (for cards/larger surfaces)
|
||||
* - `"select-tinted"` — like select-heavy but with a tinted rest background
|
||||
* - `"select-input"` — rests at neutral-00 (matches input bar), hover/open shows neutral-03 + border-01
|
||||
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
|
||||
* - `"sidebar-heavy"` — sidebar navigation items: muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
|
||||
* - `"sidebar-light"` — sidebar navigation items: uniformly muted across all states (text-02/text-02)
|
||||
|
||||
@@ -350,6 +350,41 @@
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Input — Empty
|
||||
Matches input bar background at rest, tints on hover/open.
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"] {
|
||||
@apply bg-background-neutral-00;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="hover"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:active:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="active"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-05);
|
||||
--interactive-foreground-icon: var(--text-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-disabled] {
|
||||
@apply bg-transparent;
|
||||
--interactive-foreground: var(--text-01);
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Tinted — Filled
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
16
web/package-lock.json
generated
16
web/package-lock.json
generated
@@ -47,6 +47,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
@@ -8843,6 +8844,15 @@
|
||||
"react": ">= 16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/copy-to-clipboard": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
|
||||
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"toggle-selection": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.46.0",
|
||||
"hasInstallScript": true,
|
||||
@@ -17426,6 +17436,12 @@
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/toggle-selection": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
|
||||
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/toposort": {
|
||||
"version": "2.0.2",
|
||||
"license": "MIT"
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
|
||||
@@ -17,6 +17,7 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import CollapsibleSection from "@/app/admin/agents/CollapsibleSection";
|
||||
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { RadioGroup } from "@/components/ui/radio-group";
|
||||
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
@@ -126,6 +127,24 @@ export function SlackChannelConfigFormFields({
|
||||
return documentSets.filter((ds) => !documentSetContainsSync(ds));
|
||||
}, [documentSets]);
|
||||
|
||||
const searchAgentOptions = useMemo(
|
||||
() =>
|
||||
availableAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[availableAgents]
|
||||
);
|
||||
|
||||
const nonSearchAgentOptions = useMemo(
|
||||
() =>
|
||||
nonSearchAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[nonSearchAgents]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const invalidSelected = values.document_sets.filter((dsId: number) =>
|
||||
unselectableSets.some((us) => us.id === dsId)
|
||||
@@ -355,12 +374,14 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={availableAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={searchAgentOptions}
|
||||
strict
|
||||
/>
|
||||
{viewSyncEnabledAgents && syncEnabledAgents.length > 0 && (
|
||||
<div className="mt-4">
|
||||
@@ -419,12 +440,14 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={nonSearchAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={nonSearchAgentOptions}
|
||||
strict
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -73,7 +73,10 @@ export const MemoizedAnchor = memo(
|
||||
: undefined;
|
||||
|
||||
if (!associatedDoc && !associatedSubQuestion) {
|
||||
return <>{children}</>;
|
||||
// Citation not resolved yet (data still streaming) — hide the
|
||||
// raw [[N]](url) link entirely. It will render as a chip once
|
||||
// the citation/document data arrives.
|
||||
return <></>;
|
||||
}
|
||||
|
||||
let icon: React.ReactNode = null;
|
||||
|
||||
@@ -44,6 +44,8 @@ export interface MultiModelPanelProps {
|
||||
errorStackTrace?: string | null;
|
||||
/** Additional error details */
|
||||
errorDetails?: Record<string, any> | null;
|
||||
/** Whether any model is still streaming — disables preferred selection */
|
||||
isGenerating?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -73,19 +75,24 @@ export default function MultiModelPanel({
|
||||
isRetryable,
|
||||
errorStackTrace,
|
||||
errorDetails,
|
||||
isGenerating,
|
||||
}: MultiModelPanelProps) {
|
||||
const ModelIcon = getModelIcon(provider, modelName);
|
||||
|
||||
const canSelect = !isHidden && !isPreferred && !isGenerating;
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (!isHidden && !isPreferred) onSelect();
|
||||
}, [isHidden, isPreferred, onSelect]);
|
||||
if (canSelect) onSelect();
|
||||
}, [canSelect, onSelect]);
|
||||
|
||||
const header = (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-12",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
"rounded-12 transition-colors",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00",
|
||||
canSelect && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
@@ -140,13 +147,7 @@ export default function MultiModelPanel({
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
|
||||
!isPreferred && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
<div className="flex flex-col gap-3 min-w-0 rounded-16">
|
||||
{header}
|
||||
{errorMessage ? (
|
||||
<div className="p-4">
|
||||
@@ -163,6 +164,7 @@ export default function MultiModelPanel({
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
disableTTS
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
|
||||
import {
|
||||
useState,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useRef,
|
||||
} from "react";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
@@ -110,11 +117,27 @@ export default function MultiModelResponseView({
|
||||
// Refs to each panel wrapper for height animation on deselect
|
||||
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
|
||||
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
// Tracks which non-preferred panels overflow the preferred height cap.
|
||||
// Measured via useLayoutEffect after maxHeight is applied to the DOM —
|
||||
// ref callbacks fire before layout and can't reliably detect overflow.
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (preferredPanelHeight == null || preferredIndex === null) return;
|
||||
const next = new Set<number>();
|
||||
panelElsRef.current.forEach((el, idx) => {
|
||||
if (idx === preferredIndex || hiddenPanels.has(idx)) return;
|
||||
if (el.scrollHeight > el.clientHeight) next.add(idx);
|
||||
});
|
||||
setOverflowingPanels((prev) => {
|
||||
if (prev.size === next.size && Array.from(prev).every((v) => next.has(v)))
|
||||
return prev;
|
||||
return next;
|
||||
});
|
||||
}, [preferredPanelHeight, preferredIndex, hiddenPanels, responses]);
|
||||
|
||||
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
|
||||
if (preferredRoRef.current) {
|
||||
preferredRoRef.current.disconnect();
|
||||
@@ -210,8 +233,10 @@ export default function MultiModelResponseView({
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
|
||||
// Persist preferred response to backend + update local tree so the
|
||||
// input bar unblocks (awaitingPreferredSelection clears).
|
||||
// Persist preferred response + sync `latestChildNodeId`. Backend's
|
||||
// `set_preferred_response` updates `latest_child_message_id`; if the
|
||||
// frontend chain walk disagrees, the next follow-up fails with
|
||||
// "not on the latest mainline".
|
||||
if (parentMessage?.messageId && response.messageId && currentSessionId) {
|
||||
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
|
||||
(err) => console.error("Failed to persist preferred response:", err)
|
||||
@@ -227,6 +252,7 @@ export default function MultiModelResponseView({
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: response.messageId,
|
||||
latestChildNodeId: response.nodeId,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
@@ -413,6 +439,7 @@ export default function MultiModelResponseView({
|
||||
isRetryable: response.isRetryable,
|
||||
errorStackTrace: response.errorStackTrace,
|
||||
errorDetails: response.errorDetails,
|
||||
isGenerating,
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
@@ -426,6 +453,7 @@ export default function MultiModelResponseView({
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
isGenerating,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -512,17 +540,6 @@ export default function MultiModelResponseView({
|
||||
panelElsRef.current.delete(r.modelIndex);
|
||||
}
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
setOverflowingPanels((prev) => {
|
||||
const had = prev.has(r.modelIndex);
|
||||
if (doesOverflow === had) return prev;
|
||||
const next = new Set(prev);
|
||||
if (doesOverflow) next.add(r.modelIndex);
|
||||
else next.delete(r.modelIndex);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
@@ -533,21 +550,19 @@ export default function MultiModelResponseView({
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
position: capped ? "relative" : undefined,
|
||||
...(overflows
|
||||
? {
|
||||
maskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
WebkitMaskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
}
|
||||
: {}),
|
||||
}}
|
||||
>
|
||||
<div className={cn(isNonPref && "opacity-50")}>
|
||||
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
|
||||
</div>
|
||||
{overflows && (
|
||||
<div
|
||||
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,3 +1,25 @@
|
||||
/* Map Tailwind Typography prose variables to the project's color tokens.
|
||||
These auto-switch for dark mode via colors.css — no dark: modifier needed.
|
||||
Note: text-05 = highest contrast, text-01 = lowest. */
|
||||
.prose-onyx {
|
||||
--tw-prose-body: var(--text-05);
|
||||
--tw-prose-headings: var(--text-05);
|
||||
--tw-prose-lead: var(--text-04);
|
||||
--tw-prose-links: var(--action-link-05);
|
||||
--tw-prose-bold: var(--text-05);
|
||||
--tw-prose-counters: var(--text-03);
|
||||
--tw-prose-bullets: var(--text-03);
|
||||
--tw-prose-hr: var(--border-02);
|
||||
--tw-prose-quotes: var(--text-04);
|
||||
--tw-prose-quote-borders: var(--border-02);
|
||||
--tw-prose-captions: var(--text-03);
|
||||
--tw-prose-code: var(--text-05);
|
||||
--tw-prose-pre-code: var(--text-04);
|
||||
--tw-prose-pre-bg: var(--background-code-01);
|
||||
--tw-prose-th-borders: var(--border-02);
|
||||
--tw-prose-td-borders: var(--border-01);
|
||||
}
|
||||
|
||||
/* Light mode syntax highlighting (Atom One Light) */
|
||||
.hljs {
|
||||
color: #383a42 !important;
|
||||
@@ -236,23 +258,102 @@ pre[class*="language-"] {
|
||||
scrollbar-color: #4b5563 #1f2937;
|
||||
}
|
||||
|
||||
/* Card wrapper — holds the background, border-radius, padding, and fade overlay.
|
||||
Does NOT scroll — the inner .markdown-table-breakout handles that. */
|
||||
.markdown-table-card {
|
||||
position: relative;
|
||||
background: var(--background-neutral-01);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Table breakout container - allows tables to extend beyond their parent's
|
||||
* constrained width to use the full container query width (100cqw).
|
||||
*
|
||||
* Requires an ancestor element with `container-type: inline-size` (@container in Tailwind).
|
||||
*
|
||||
* How the math works:
|
||||
* - width: 100cqw → expand to full container query width
|
||||
* - marginLeft: calc((100% - 100cqw) / 2) → negative margin pulls element left
|
||||
* (100% is parent width, 100cqw is larger, so result is negative)
|
||||
* - paddingLeft/Right: calc((100cqw - 100%) / 2) → padding keeps content aligned
|
||||
* with original position while allowing scroll area to extend
|
||||
* Scrollable table container — sits inside the card.
|
||||
*/
|
||||
.markdown-table-breakout {
|
||||
overflow-x: auto;
|
||||
width: 100cqw;
|
||||
margin-left: calc((100% - 100cqw) / 2);
|
||||
padding-left: calc((100cqw - 100%) / 2);
|
||||
padding-right: calc((100cqw - 100%) / 2);
|
||||
|
||||
/* Always reserve scrollbar height so hover doesn't shift content.
|
||||
Thumb is transparent by default, revealed on hover. */
|
||||
scrollbar-width: thin; /* Firefox — always shows track */
|
||||
scrollbar-color: transparent transparent; /* invisible thumb + track */
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar {
|
||||
height: 6px;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.markdown-table-breakout:hover {
|
||||
scrollbar-color: var(--border-03) transparent; /* Firefox — reveal thumb */
|
||||
}
|
||||
.markdown-table-breakout:hover::-webkit-scrollbar-thumb {
|
||||
background: var(--border-03);
|
||||
}
|
||||
|
||||
/* Fade the right edge via an ::after overlay on the non-scrolling card.
|
||||
Stays pinned while table scrolls; doesn't affect the sticky column. */
|
||||
.markdown-table-card::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
width: 2rem;
|
||||
pointer-events: none;
|
||||
z-index: 2;
|
||||
background: linear-gradient(
|
||||
to right,
|
||||
transparent,
|
||||
var(--background-neutral-01)
|
||||
);
|
||||
border-radius: 0 0.5rem 0.5rem 0;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.markdown-table-card[data-overflows="true"]::after {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Sticky first column — inherits the container's background so it
|
||||
matches regardless of theme or custom wallpaper. */
|
||||
.markdown-table-breakout th:first-child,
|
||||
.markdown-table-breakout td:first-child {
|
||||
position: sticky;
|
||||
left: 0;
|
||||
z-index: 1;
|
||||
padding-left: 0.75rem;
|
||||
background: var(--background-neutral-01);
|
||||
}
|
||||
.markdown-table-breakout th:last-child,
|
||||
.markdown-table-breakout td:last-child {
|
||||
padding-right: 0.75rem;
|
||||
}
|
||||
|
||||
/* Shadow on sticky column when scrolled. Uses an ::after pseudo-element
|
||||
so it isn't clipped by the overflow container or the mask-image fade. */
|
||||
.markdown-table-breakout th:first-child::after,
|
||||
.markdown-table-breakout td:first-child::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: -6px;
|
||||
bottom: 0;
|
||||
width: 6px;
|
||||
pointer-events: none;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-25);
|
||||
}
|
||||
.dark .markdown-table-breakout th:first-child::after,
|
||||
.dark .markdown-table-breakout td:first-child::after {
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-60);
|
||||
}
|
||||
.markdown-table-breakout[data-scrolled="true"] th:first-child::after,
|
||||
.markdown-table-breakout[data-scrolled="true"] td:first-child::after {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@@ -51,6 +51,8 @@ export interface AgentMessageProps {
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
/** Skip TTS streaming (used in multi-model where voice doesn't apply) */
|
||||
disableTTS?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -99,6 +101,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
disableTTS,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -133,32 +136,49 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
finalAnswerComing
|
||||
);
|
||||
|
||||
// Memoize merged citations separately to avoid creating new object when neither source changed
|
||||
// Merge streaming citation/document data with chatState props.
|
||||
// NOTE: citationMap and documentMap from usePacketProcessor are mutated in
|
||||
// place (same object reference), so we use citations.length / documentMap.size
|
||||
// as change-detection proxies to bust the memo cache when new data arrives.
|
||||
const mergedCitations = useMemo(
|
||||
() => ({
|
||||
...chatState.citations,
|
||||
...citationMap,
|
||||
}),
|
||||
[chatState.citations, citationMap]
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[chatState.citations, citationMap, citations.length]
|
||||
);
|
||||
|
||||
// Create a chatState that uses streaming citations for immediate rendering
|
||||
// This merges the prop citations with streaming citations, preferring streaming ones
|
||||
// Memoized with granular dependencies to prevent cascading re-renders
|
||||
// Merge streaming documentMap into chatState.docs so inline citation chips
|
||||
// can resolve [1] → document even when chatState.docs is empty (multi-model).
|
||||
const mergedDocs = useMemo(() => {
|
||||
const propDocs = chatState.docs ?? [];
|
||||
if (documentMap.size === 0) return propDocs;
|
||||
const seen = new Set(propDocs.map((d) => d.document_id));
|
||||
const extras = Array.from(documentMap.values()).filter(
|
||||
(d) => !seen.has(d.document_id)
|
||||
);
|
||||
return extras.length > 0 ? [...propDocs, ...extras] : propDocs;
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [chatState.docs, documentMap, documentMap.size]);
|
||||
|
||||
// Create a chatState that uses streaming citations and documents for immediate rendering.
|
||||
// Memoized with granular dependencies to prevent cascading re-renders.
|
||||
// Note: chatState object is recreated upstream on every render, so we depend on
|
||||
// individual fields instead of the whole object for proper memoization
|
||||
// individual fields instead of the whole object for proper memoization.
|
||||
const effectiveChatState = useMemo<FullChatState>(
|
||||
() => ({
|
||||
...chatState,
|
||||
citations: mergedCitations,
|
||||
docs: mergedDocs,
|
||||
}),
|
||||
[
|
||||
chatState.agent,
|
||||
chatState.docs,
|
||||
chatState.setPresentingDocument,
|
||||
chatState.overriddenModel,
|
||||
chatState.researchType,
|
||||
mergedCitations,
|
||||
mergedDocs,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -202,6 +222,9 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
// Multi-model: skip TTS entirely
|
||||
if (disableTTS) return;
|
||||
|
||||
// If user cancelled generation, do not send more text to TTS.
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
ttsCompletedRef.current = true;
|
||||
@@ -305,7 +328,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRenderComplete();
|
||||
}
|
||||
}}
|
||||
animate={false}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
|
||||
@@ -59,7 +59,6 @@ function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useCallback, useMemo, JSX } from "react";
|
||||
import React, { useCallback, useEffect, useRef, useMemo, JSX } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
@@ -17,10 +17,79 @@ import { transformLinkUri, cn } from "@/lib/utils";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
|
||||
/** Table wrapper that detects horizontal overflow and shows a fade + scrollbar. */
|
||||
interface ScrollableTableProps
|
||||
extends React.TableHTMLAttributes<HTMLTableElement> {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function ScrollableTable({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: ScrollableTableProps) {
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const wrapRef = useRef<HTMLDivElement>(null);
|
||||
const tableRef = useRef<HTMLTableElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const el = scrollRef.current;
|
||||
const wrap = wrapRef.current;
|
||||
const table = tableRef.current;
|
||||
if (!el || !wrap) return;
|
||||
|
||||
const check = () => {
|
||||
const overflows = el.scrollWidth > el.clientWidth;
|
||||
const atEnd = el.scrollLeft + el.clientWidth >= el.scrollWidth - 2;
|
||||
wrap.dataset.overflows = overflows && !atEnd ? "true" : "false";
|
||||
el.dataset.scrolled = el.scrollLeft > 0 ? "true" : "false";
|
||||
};
|
||||
|
||||
check();
|
||||
el.addEventListener("scroll", check, { passive: true });
|
||||
// Observe both the scroll container (parent resize) and the table
|
||||
// itself (content growth during streaming).
|
||||
const ro = new ResizeObserver(check);
|
||||
ro.observe(el);
|
||||
if (table) ro.observe(table);
|
||||
|
||||
return () => {
|
||||
el.removeEventListener("scroll", check);
|
||||
ro.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div ref={wrapRef} className="markdown-table-card">
|
||||
<div ref={scrollRef} className="markdown-table-breakout">
|
||||
<table
|
||||
ref={tableRef}
|
||||
className={cn(
|
||||
className,
|
||||
"min-w-full !my-0 [&_th]:whitespace-nowrap [&_td]:whitespace-nowrap"
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
*/
|
||||
export const processContent = (content: string): string => {
|
||||
// Strip incomplete citation links at the end of streaming content.
|
||||
// During typewriter animation, [[N]](url) is revealed character by character.
|
||||
// ReactMarkdown can't parse an incomplete link and renders it as raw text.
|
||||
// This regex removes any trailing partial citation pattern so only complete
|
||||
// links are passed to the markdown parser.
|
||||
content = content.replace(/\[\[\d+\]\]\([^)]*$/, "");
|
||||
// Also strip a lone [[ or [[N] or [[N]] at the very end (before the URL part arrives)
|
||||
content = content.replace(/\[\[(?:\d+\]?\]?)?$/, "");
|
||||
|
||||
const codeBlockRegex = /```(\w*)\n[\s\S]*?```|```[\s\S]*?$/g;
|
||||
const matches = content.match(codeBlockRegex);
|
||||
|
||||
@@ -127,11 +196,9 @@ export const useMarkdownComponents = (
|
||||
},
|
||||
table: ({ node, className, children, ...props }: any) => {
|
||||
return (
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...props}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
<ScrollableTable className={className} {...props}>
|
||||
{children}
|
||||
</ScrollableTable>
|
||||
);
|
||||
},
|
||||
code: ({ node, className, children }: any) => {
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import ReactMarkdown, { Components } from "react-markdown";
|
||||
import type { PluggableList } from "unified";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeHighlight from "rehype-highlight";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { useTypewriter } from "@/hooks/useTypewriter";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
ChatPacket,
|
||||
PacketType,
|
||||
@@ -8,16 +16,22 @@ import {
|
||||
} from "../../../services/streamingModels";
|
||||
import { MessageRenderer, FullChatState } from "../interfaces";
|
||||
import { isFinalAnswerComplete } from "../../../services/packetUtils";
|
||||
import { useMarkdownRenderer } from "../markdownUtils";
|
||||
import { processContent } from "../markdownUtils";
|
||||
import { BlinkingBar } from "../../BlinkingBar";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import {
|
||||
MemoizedAnchor,
|
||||
MemoizedParagraph,
|
||||
} from "@/app/app/message/MemoizedTextComponents";
|
||||
import { extractCodeText } from "@/app/app/message/codeUtils";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
import { cn, transformLinkUri } from "@/lib/utils";
|
||||
|
||||
/**
|
||||
* Maps a cleaned character position to the corresponding position in markdown text.
|
||||
* This allows progressive reveal to work with markdown formatting.
|
||||
*/
|
||||
/** Maps a visible-char count to a markdown index (skips formatting chars,
|
||||
* extends to word boundary). Used by the voice-sync reveal path only. */
|
||||
function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
// Skip patterns that don't contribute to visible character count
|
||||
const skipChars = new Set(["*", "`", "#"]);
|
||||
let cleanIndex = 0;
|
||||
let mdIndex = 0;
|
||||
@@ -25,13 +39,11 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
|
||||
const char = markdown[mdIndex];
|
||||
|
||||
// Skip markdown formatting characters
|
||||
if (char !== undefined && skipChars.has(char)) {
|
||||
mdIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle link syntax [text](url) - skip the (url) part but count the text
|
||||
if (
|
||||
char === "]" &&
|
||||
mdIndex + 1 < markdown.length &&
|
||||
@@ -48,7 +60,6 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
// Extend to word boundary to avoid cutting mid-word
|
||||
while (
|
||||
mdIndex < markdown.length &&
|
||||
markdown[mdIndex] !== " " &&
|
||||
@@ -60,8 +71,15 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
return mdIndex;
|
||||
}
|
||||
|
||||
// Control the rate of packet streaming (packets per second)
|
||||
const PACKET_DELAY_MS = 10;
|
||||
// Cheap streaming plugins (gfm only) → cheap per-frame parse. Full
|
||||
// pipeline flips in once, at the end, for syntax highlighting + math.
|
||||
const STREAMING_REMARK_PLUGINS: PluggableList = [remarkGfm];
|
||||
const STREAMING_REHYPE_PLUGINS: PluggableList = [];
|
||||
const FULL_REMARK_PLUGINS: PluggableList = [
|
||||
remarkGfm,
|
||||
[remarkMath, { singleDollarTextMath: true }],
|
||||
];
|
||||
const FULL_REHYPE_PLUGINS: PluggableList = [rehypeHighlight, rehypeKatex];
|
||||
|
||||
export const MessageTextRenderer: MessageRenderer<
|
||||
ChatPacket,
|
||||
@@ -78,19 +96,17 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
stopReason,
|
||||
children,
|
||||
}) => {
|
||||
// If we're animating and the final answer is already complete, show more packets initially
|
||||
const initialPacketCount = animate
|
||||
? packets.length > 0
|
||||
? 1 // Otherwise start with 1 packet
|
||||
: 0
|
||||
: -1; // Show all if not animating
|
||||
|
||||
const [displayedPacketCount, setDisplayedPacketCount] =
|
||||
useState(initialPacketCount);
|
||||
const lastStableSyncedContentRef = useRef("");
|
||||
const lastVisibleContentRef = useRef("");
|
||||
|
||||
// Get voice mode context for progressive text reveal synced with audio
|
||||
// Timeout guard: if TTS doesn't start within 5s of voice sync
|
||||
// activating, fall back to normal streaming. Prevents permanent
|
||||
// content suppression when the voice WebSocket fails to connect.
|
||||
const [voiceSyncTimedOut, setVoiceSyncTimedOut] = useState(false);
|
||||
const voiceSyncTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null
|
||||
);
|
||||
|
||||
const {
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
@@ -99,7 +115,6 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
isAwaitingAutoPlaybackStart,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = packets
|
||||
.map((packet) => {
|
||||
if (
|
||||
@@ -114,117 +129,74 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
|
||||
const shouldUseAutoPlaybackSync =
|
||||
autoPlayback &&
|
||||
!voiceSyncTimedOut &&
|
||||
typeof messageNodeId === "number" &&
|
||||
activeMessageNodeId === messageNodeId;
|
||||
|
||||
// Animation effect - gradually increase displayed packets at controlled rate
|
||||
// Start/clear the timeout when voice sync activates/deactivates.
|
||||
useEffect(() => {
|
||||
if (!animate) {
|
||||
setDisplayedPacketCount(-1); // Show all packets
|
||||
return;
|
||||
}
|
||||
|
||||
if (displayedPacketCount >= 0 && displayedPacketCount < packets.length) {
|
||||
const timer = setTimeout(() => {
|
||||
setDisplayedPacketCount((prev) => Math.min(prev + 1, packets.length));
|
||||
}, PACKET_DELAY_MS);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [animate, displayedPacketCount, packets.length]);
|
||||
|
||||
// Reset displayed count when packet array changes significantly (e.g., new message)
|
||||
useEffect(() => {
|
||||
if (animate && packets.length < displayedPacketCount) {
|
||||
const resetCount = isFinalAnswerComplete(packets)
|
||||
? Math.min(10, packets.length)
|
||||
: packets.length > 0
|
||||
? 1
|
||||
: 0;
|
||||
setDisplayedPacketCount(resetCount);
|
||||
}
|
||||
}, [animate, packets.length, displayedPacketCount]);
|
||||
|
||||
// Only mark as complete when all packets are received AND displayed
|
||||
useEffect(() => {
|
||||
if (isFinalAnswerComplete(packets)) {
|
||||
// If animating, wait until all packets are displayed
|
||||
if (
|
||||
animate &&
|
||||
displayedPacketCount >= 0 &&
|
||||
displayedPacketCount < packets.length
|
||||
) {
|
||||
return;
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
if (!voiceSyncTimeoutRef.current) {
|
||||
voiceSyncTimeoutRef.current = setTimeout(() => {
|
||||
setVoiceSyncTimedOut(true);
|
||||
}, 5000);
|
||||
}
|
||||
onComplete();
|
||||
} else {
|
||||
// TTS started or sync deactivated — clear timeout
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
if (voiceSyncTimedOut && !autoPlayback) setVoiceSyncTimedOut(false);
|
||||
}
|
||||
}, [packets, onComplete, animate, displayedPacketCount]);
|
||||
return () => {
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
voiceSyncTimedOut,
|
||||
]);
|
||||
|
||||
// Get content based on displayed packet count or audio progress
|
||||
// Normal streaming hands full text to the typewriter. Voice-sync
|
||||
// paths pre-slice and bypass. If shouldUseAutoPlaybackSync is false
|
||||
// (including after the 5s timeout), all paths fall through to fullContent.
|
||||
const computedContent = useMemo(() => {
|
||||
// Hold response in "thinking" state only while autoplay startup is pending.
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Sync text with audio only for the message currently being spoken.
|
||||
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
|
||||
const MIN_REVEAL_CHARS = 12;
|
||||
if (revealedCharCount < MIN_REVEAL_CHARS) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Reveal text progressively based on audio progress
|
||||
const revealPos = getRevealPosition(fullContent, revealedCharCount);
|
||||
return fullContent.slice(0, Math.max(revealPos, 0));
|
||||
}
|
||||
|
||||
// During an active synced turn, if sync temporarily drops, keep current reveal
|
||||
// instead of jumping to full content or blanking.
|
||||
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
// Standard behavior when auto-playback is off
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent; // Show all content
|
||||
}
|
||||
|
||||
// Packet-based reveal (when auto-playback is disabled)
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
if (
|
||||
packet.obj.type === PacketType.MESSAGE_DELTA ||
|
||||
packet.obj.type === PacketType.MESSAGE_START
|
||||
) {
|
||||
return packet.obj.content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
return fullContent;
|
||||
}, [
|
||||
animate,
|
||||
displayedPacketCount,
|
||||
fullContent,
|
||||
packets,
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
messageNodeId,
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
revealedCharCount,
|
||||
fullContent,
|
||||
stopPacketSeen,
|
||||
]);
|
||||
|
||||
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
|
||||
// Monotonic guard for voice sync + freeze on user cancel.
|
||||
const content = useMemo(() => {
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
// On user cancel during live streaming, freeze at exactly what was already
|
||||
// visible to prevent flicker. On history reload (animate=false), the ref
|
||||
// starts empty so we must use computedContent directly.
|
||||
if (wasUserCancelled && animate) {
|
||||
return lastVisibleContentRef.current;
|
||||
}
|
||||
@@ -242,13 +214,10 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
// If content shape changed unexpectedly mid-stream, prefer the stable version
|
||||
// to avoid flicker/dumps.
|
||||
if (!stopPacketSeen || wasUserCancelled) {
|
||||
return last;
|
||||
}
|
||||
|
||||
// For normal completed responses, allow final full content.
|
||||
return computedContent;
|
||||
}, [
|
||||
computedContent,
|
||||
@@ -258,7 +227,6 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
animate,
|
||||
]);
|
||||
|
||||
// Sync the stable ref outside of useMemo to avoid side effects during render.
|
||||
useEffect(() => {
|
||||
if (stopReason === StopReason.USER_CANCELLED) {
|
||||
return;
|
||||
@@ -270,13 +238,128 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
}
|
||||
}, [content, shouldUseAutoPlaybackSync, stopReason]);
|
||||
|
||||
// Track last actually rendered content so cancel can freeze without dumping buffered text.
|
||||
useEffect(() => {
|
||||
if (content.length > 0) {
|
||||
lastVisibleContentRef.current = content;
|
||||
}
|
||||
}, [content]);
|
||||
|
||||
const isStreamingAnimationEnabled =
|
||||
animate &&
|
||||
!shouldUseAutoPlaybackSync &&
|
||||
stopReason !== StopReason.USER_CANCELLED;
|
||||
|
||||
const isStreamFinished = isFinalAnswerComplete(packets);
|
||||
|
||||
const displayedContent = useTypewriter(content, isStreamingAnimationEnabled);
|
||||
|
||||
// One-way signal: stream done AND typewriter caught up. Do NOT derive
|
||||
// this from "typewriter currently behind" — it oscillates mid-stream
|
||||
// between packet bursts and would thrash the plugin pipeline.
|
||||
const streamFullyDisplayed =
|
||||
isStreamFinished && displayedContent.length >= content.length;
|
||||
|
||||
// Fire onComplete exactly once per mount. `onComplete` is an inline
|
||||
// arrow in AgentMessage so its identity changes on every parent render;
|
||||
// without this guard, each new identity would re-fire the effect once
|
||||
// `streamFullyDisplayed` is true.
|
||||
const onCompleteFiredRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (streamFullyDisplayed && !onCompleteFiredRef.current) {
|
||||
onCompleteFiredRef.current = true;
|
||||
onComplete();
|
||||
}
|
||||
}, [streamFullyDisplayed, onComplete]);
|
||||
|
||||
const processedContent = useMemo(
|
||||
() => processContent(displayedContent),
|
||||
[displayedContent]
|
||||
);
|
||||
|
||||
// Stable-identity components for ReactMarkdown. Dynamic data (`state`,
|
||||
// `processedContent`) flows through refs so the callback identities
|
||||
// never change — otherwise every typewriter tick would invalidate
|
||||
// React reconciliation on the markdown subtree.
|
||||
const stateRef = useRef(state);
|
||||
stateRef.current = state;
|
||||
const processedContentRef = useRef(processedContent);
|
||||
processedContentRef.current = processedContent;
|
||||
|
||||
const markdownComponents = useMemo<Components>(
|
||||
() => ({
|
||||
a: ({ href, children }) => {
|
||||
const s = stateRef.current;
|
||||
const imageFileId = extractChatImageFileId(
|
||||
href,
|
||||
String(children ?? "")
|
||||
);
|
||||
if (imageFileId) {
|
||||
return (
|
||||
<InMessageImage
|
||||
fileId={imageFileId}
|
||||
fileName={String(children ?? "")}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={s?.setPresentingDocument || (() => {})}
|
||||
docs={s?.docs || []}
|
||||
userFiles={s?.userFiles || []}
|
||||
citations={s?.citations}
|
||||
href={href}
|
||||
>
|
||||
{children}
|
||||
</MemoizedAnchor>
|
||||
);
|
||||
},
|
||||
p: ({ children }) => (
|
||||
<MemoizedParagraph className="font-main-content-body">
|
||||
{children}
|
||||
</MemoizedParagraph>
|
||||
),
|
||||
pre: ({ children }) => <>{children}</>,
|
||||
b: ({ className, children }) => (
|
||||
<span className={className}>{children}</span>
|
||||
),
|
||||
ul: ({ className, children, ...rest }) => (
|
||||
<ul className={className} {...rest}>
|
||||
{children}
|
||||
</ul>
|
||||
),
|
||||
ol: ({ className, children, ...rest }) => (
|
||||
<ol className={className} {...rest}>
|
||||
{children}
|
||||
</ol>
|
||||
),
|
||||
li: ({ className, children, ...rest }) => (
|
||||
<li className={className} {...rest}>
|
||||
{children}
|
||||
</li>
|
||||
),
|
||||
table: ({ className, children, ...rest }) => (
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...rest}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
),
|
||||
code: ({ node, className, children }) => {
|
||||
const codeText = extractCodeText(
|
||||
node,
|
||||
processedContentRef.current,
|
||||
children
|
||||
);
|
||||
return (
|
||||
<CodeBlock className={className} codeText={codeText}>
|
||||
{children}
|
||||
</CodeBlock>
|
||||
);
|
||||
},
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
const shouldShowThinkingPlaceholder =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
isAwaitingAutoPlaybackStart &&
|
||||
@@ -292,16 +375,16 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowCursor =
|
||||
content.length > 0 &&
|
||||
(!stopPacketSeen ||
|
||||
displayedContent.length > 0 &&
|
||||
((isStreamingAnimationEnabled && !streamFullyDisplayed) ||
|
||||
(!isStreamingAnimationEnabled && !stopPacketSeen) ||
|
||||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
|
||||
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
// the [*]() is a hack to show a blinking dot when the packet is not complete
|
||||
shouldShowCursor ? content + " [*]() " : content,
|
||||
state,
|
||||
"font-main-content-body"
|
||||
);
|
||||
// `[*]() ` is rendered by the anchor component as an inline blinking
|
||||
// caret, keeping it flush with the trailing character.
|
||||
const markdownInput = shouldShowCursor
|
||||
? processedContent + " [*]() "
|
||||
: processedContent;
|
||||
|
||||
return children([
|
||||
{
|
||||
@@ -312,8 +395,26 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
<Text as="span" secondaryBody text04 className="italic">
|
||||
Thinking
|
||||
</Text>
|
||||
) : content.length > 0 ? (
|
||||
<>{renderedContent}</>
|
||||
) : displayedContent.length > 0 ? (
|
||||
<div dir="auto">
|
||||
<ReactMarkdown
|
||||
className="prose prose-onyx font-main-content-body max-w-full"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REMARK_PLUGINS
|
||||
: STREAMING_REMARK_PLUGINS
|
||||
}
|
||||
rehypePlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REHYPE_PLUGINS
|
||||
: STREAMING_REHYPE_PLUGINS
|
||||
}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{markdownInput}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
) : (
|
||||
<BlinkingBar addMargin />
|
||||
),
|
||||
|
||||
@@ -34,7 +34,8 @@ export const PROVIDERS: ProviderConfig[] = [
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
export interface BuildLlmSelection {
|
||||
providerName: string; // e.g., "build-mode-anthropic" (LLMProviderDescriptor.name)
|
||||
provider: string; // e.g., "anthropic"
|
||||
modelName: string; // e.g., "claude-opus-4-6"
|
||||
modelName: string; // e.g., "claude-opus-4-7"
|
||||
}
|
||||
|
||||
// Priority order for smart default LLM selection
|
||||
const LLM_SELECTION_PRIORITY = [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-7" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
|
||||
] as const;
|
||||
@@ -63,10 +63,11 @@ export function getDefaultLlmSelection(
|
||||
export const RECOMMENDED_BUILD_MODELS = {
|
||||
preferred: {
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
modelName: "claude-opus-4-7",
|
||||
displayName: "Claude Opus 4.7",
|
||||
},
|
||||
alternatives: [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "anthropic", modelName: "claude-sonnet-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openai", modelName: "gpt-5.1-codex" },
|
||||
@@ -148,7 +149,8 @@ export const BUILD_MODE_PROVIDERS: BuildModeProvider[] = [
|
||||
providerName: "anthropic",
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -271,6 +271,22 @@ export default function UserLibraryModal({
|
||||
/>
|
||||
</Section>
|
||||
|
||||
{/* The exact cap is controlled by the backend env var
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE (default 500). This copy is
|
||||
deliberately vague so it doesn't drift if the limit is
|
||||
tuned per-deployment; the precise number is surfaced in
|
||||
the rejection error the server returns. */}
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
padding={0.5}
|
||||
height="fit"
|
||||
>
|
||||
<Text secondaryBody text03>
|
||||
PDFs with many embedded images may be rejected.
|
||||
</Text>
|
||||
</Section>
|
||||
|
||||
{isLoading ? (
|
||||
<Section padding={2} height="fit">
|
||||
<Text secondaryBody text03>
|
||||
|
||||
@@ -320,7 +320,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
@@ -332,7 +332,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: chatMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
@@ -370,10 +370,16 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
|
||||
}, [
|
||||
messageHistory,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
multiModel.isMultiModelActive,
|
||||
]);
|
||||
|
||||
// Start a new chat session in the side panel
|
||||
const handleNewChat = useCallback(() => {
|
||||
@@ -516,8 +522,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"w-full flex flex-col",
|
||||
!isSidePanel &&
|
||||
"max-w-[var(--app-page-main-content-width)] px-4"
|
||||
!isSidePanel && "max-w-[var(--app-page-main-content-width)]"
|
||||
)}
|
||||
>
|
||||
{hasMessages && liveAgent && !llmManager.isLoadingProviders && (
|
||||
@@ -535,6 +540,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
isMultiModelActive={multiModel.isMultiModelActive}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
initialMessage={message}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { SourceIcon } from "./SourceIcon";
|
||||
import { useState } from "react";
|
||||
import { OnyxIcon } from "./icons/icons";
|
||||
import { GithubIcon, OnyxIcon } from "./icons/icons";
|
||||
|
||||
export function WebResultIcon({
|
||||
url,
|
||||
@@ -23,6 +23,8 @@ export function WebResultIcon({
|
||||
<>
|
||||
{hostname.includes("onyx.app") ? (
|
||||
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
|
||||
) : hostname === "github.com" || hostname.endsWith(".github.com") ? (
|
||||
<GithubIcon size={size} />
|
||||
) : !error ? (
|
||||
<img
|
||||
className="my-0 rounded-full py-0"
|
||||
|
||||
@@ -46,6 +46,7 @@ import freshdeskIcon from "@public/Freshdesk.png";
|
||||
import geminiSVG from "@public/Gemini.svg";
|
||||
import gitbookDarkIcon from "@public/GitBookDark.png";
|
||||
import gitbookLightIcon from "@public/GitBookLight.png";
|
||||
import githubDarkIcon from "@public/GithubDarkMode.png";
|
||||
import githubLightIcon from "@public/Github.png";
|
||||
import gongIcon from "@public/Gong.png";
|
||||
import googleIcon from "@public/Google.png";
|
||||
@@ -855,7 +856,7 @@ export const GitbookIcon = createLogoIcon(gitbookDarkIcon, {
|
||||
darkSrc: gitbookLightIcon,
|
||||
});
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true,
|
||||
darkSrc: githubDarkIcon,
|
||||
});
|
||||
export const GitlabIcon = createLogoIcon(gitlabIcon);
|
||||
export const GmailIcon = createLogoIcon(gmailIcon);
|
||||
|
||||
@@ -12,9 +12,9 @@ interface LLMOption {
|
||||
value: string;
|
||||
icon: ReturnType<typeof getModelIcon>;
|
||||
modelName: string;
|
||||
providerId: number;
|
||||
providerName: string;
|
||||
provider: string;
|
||||
providerDisplayName: string;
|
||||
supportsImageInput: boolean;
|
||||
vendor: string | null;
|
||||
}
|
||||
@@ -64,7 +64,7 @@ export default function LLMSelector({
|
||||
return;
|
||||
}
|
||||
|
||||
const key = `${provider.provider}:${modelConfiguration.name}`;
|
||||
const key = `${provider.id}:${modelConfiguration.name}`;
|
||||
if (seenKeys.has(key)) {
|
||||
return; // Skip exact duplicate
|
||||
}
|
||||
@@ -87,10 +87,9 @@ export default function LLMSelector({
|
||||
),
|
||||
icon: getModelIcon(provider.provider, modelConfiguration.name),
|
||||
modelName: modelConfiguration.name,
|
||||
providerId: provider.id,
|
||||
providerName: provider.name,
|
||||
provider: provider.provider,
|
||||
providerDisplayName:
|
||||
provider.provider_display_name || provider.provider,
|
||||
supportsImageInput,
|
||||
vendor: modelConfiguration.vendor || null,
|
||||
};
|
||||
@@ -108,33 +107,34 @@ export default function LLMSelector({
|
||||
requiresImageGeneration,
|
||||
]);
|
||||
|
||||
// Group options by provider using backend-provided display names
|
||||
// Group options by configured provider instance so multiple instances of the
|
||||
// same provider type (e.g., two Anthropic API keys) appear as separate groups
|
||||
// labeled with their user-given names.
|
||||
const groupedOptions = useMemo(() => {
|
||||
const groups = new Map<
|
||||
string,
|
||||
number,
|
||||
{ displayName: string; options: LLMOption[] }
|
||||
>();
|
||||
|
||||
llmOptions.forEach((option) => {
|
||||
const provider = option.provider.toLowerCase();
|
||||
if (!groups.has(provider)) {
|
||||
groups.set(provider, {
|
||||
displayName: option.providerDisplayName,
|
||||
if (!groups.has(option.providerId)) {
|
||||
groups.set(option.providerId, {
|
||||
displayName: option.providerName,
|
||||
options: [],
|
||||
});
|
||||
}
|
||||
groups.get(provider)!.options.push(option);
|
||||
groups.get(option.providerId)!.options.push(option);
|
||||
});
|
||||
|
||||
// Sort groups alphabetically by display name
|
||||
const sortedProviders = Array.from(groups.keys()).sort((a, b) =>
|
||||
const sortedProviderIds = Array.from(groups.keys()).sort((a, b) =>
|
||||
groups.get(a)!.displayName.localeCompare(groups.get(b)!.displayName)
|
||||
);
|
||||
|
||||
return sortedProviders.map((provider) => {
|
||||
const group = groups.get(provider)!;
|
||||
return sortedProviderIds.map((providerId) => {
|
||||
const group = groups.get(providerId)!;
|
||||
return {
|
||||
provider,
|
||||
providerId,
|
||||
displayName: group.displayName,
|
||||
options: group.options,
|
||||
};
|
||||
@@ -179,7 +179,7 @@ export default function LLMSelector({
|
||||
)}
|
||||
{showGrouped
|
||||
? groupedOptions.map((group) => (
|
||||
<InputSelect.Group key={group.provider}>
|
||||
<InputSelect.Group key={group.providerId}>
|
||||
<InputSelect.Label>{group.displayName}</InputSelect.Label>
|
||||
{group.options.map((option) => (
|
||||
<InputSelect.Item
|
||||
|
||||
@@ -644,6 +644,7 @@ export default function useChatController({
|
||||
});
|
||||
node.modelDisplayName = model.displayName;
|
||||
node.overridden_model = model.modelName;
|
||||
node.is_generating = true;
|
||||
return node;
|
||||
});
|
||||
}
|
||||
@@ -711,6 +712,13 @@ export default function useChatController({
|
||||
? selectedModels?.map((m) => m.displayName) ?? []
|
||||
: [];
|
||||
|
||||
// rAF-batched flush state. One Zustand write per frame instead of
|
||||
// one per packet.
|
||||
const dirtyModelIndices = new Set<number>();
|
||||
let singleModelDirty = false;
|
||||
let userNodeDirty = false;
|
||||
let pendingFlush = false;
|
||||
|
||||
/** Build a non-errored multi-model assistant node for upsert. */
|
||||
function buildAssistantNodeUpdate(
|
||||
idx: number,
|
||||
@@ -740,16 +748,124 @@ export default function useChatController({
|
||||
};
|
||||
}
|
||||
|
||||
/** Build updated nodes for all non-errored models. */
|
||||
function buildNonErroredNodes(overrides?: Partial<Message>): Message[] {
|
||||
/** With `onlyDirty`, rebuilds only those model nodes — unchanged
|
||||
* siblings keep their stable Message ref so React memo short-circuits. */
|
||||
function buildNonErroredNodes(
|
||||
overrides?: Partial<Message>,
|
||||
onlyDirty?: Set<number> | null
|
||||
): Message[] {
|
||||
const nodes: Message[] = [];
|
||||
for (let idx = 0; idx < initialAssistantNodes.length; idx++) {
|
||||
if (erroredModelIndices.has(idx)) continue;
|
||||
if (onlyDirty && !onlyDirty.has(idx)) continue;
|
||||
nodes.push(buildAssistantNodeUpdate(idx, overrides));
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
/** Flush accumulated packet state into the tree as one Zustand
|
||||
* update. No-op when nothing is pending. */
|
||||
function flushPendingUpdates() {
|
||||
if (!pendingFlush) return;
|
||||
pendingFlush = false;
|
||||
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
let messagesToUpsert: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
if (dirtyModelIndices.size === 0 && !userNodeDirty) return;
|
||||
|
||||
const dirtySnapshot = new Set(dirtyModelIndices);
|
||||
dirtyModelIndices.clear();
|
||||
const dirtyNodes = buildNonErroredNodes(undefined, dirtySnapshot);
|
||||
|
||||
if (userNodeDirty) {
|
||||
userNodeDirty = false;
|
||||
// Read current user node to preserve childrenNodeIds
|
||||
// (initialUserNode's are stale from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsert = [updatedUserNode, ...dirtyNodes];
|
||||
} else {
|
||||
messagesToUpsert = dirtyNodes;
|
||||
}
|
||||
|
||||
if (messagesToUpsert.length === 0) return;
|
||||
} else {
|
||||
if (!singleModelDirty) return;
|
||||
singleModelDirty = false;
|
||||
|
||||
messagesToUpsert = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsert,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
}
|
||||
|
||||
/** Awaits next animation frame (or a setTimeout fallback when the
|
||||
* tab is hidden — rAF is paused in background tabs, which would
|
||||
* otherwise hang the stream loop here), then flushes. Aligns
|
||||
* React updates with the paint cycle when visible. */
|
||||
function flushViaRAF(): Promise<void> {
|
||||
return new Promise<void>((resolve) => {
|
||||
let done = false;
|
||||
const flush = () => {
|
||||
if (done) return;
|
||||
done = true;
|
||||
flushPendingUpdates();
|
||||
resolve();
|
||||
};
|
||||
requestAnimationFrame(flush);
|
||||
// Fallback for hidden tabs where rAF is paused. Throttled to
|
||||
// ~1s by browsers, matching the previous setTimeout(500) cadence.
|
||||
setTimeout(flush, 100);
|
||||
});
|
||||
}
|
||||
|
||||
let streamSucceeded = false;
|
||||
|
||||
try {
|
||||
@@ -836,7 +952,12 @@ export default function useChatController({
|
||||
await delay(50);
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
if (stack.isEmpty()) {
|
||||
await delay(0.5);
|
||||
// Flush the burst on the next paint, or idle briefly.
|
||||
if (pendingFlush) {
|
||||
await flushViaRAF();
|
||||
} else {
|
||||
await delay(0.5);
|
||||
}
|
||||
}
|
||||
|
||||
if (!stack.isEmpty() && !controller.signal.aborted) {
|
||||
@@ -860,6 +981,7 @@ export default function useChatController({
|
||||
if ((packet as MessageResponseIDInfo).user_message_id) {
|
||||
newUserMessageId = (packet as MessageResponseIDInfo)
|
||||
.user_message_id;
|
||||
userNodeDirty = true;
|
||||
|
||||
// Track extension queries in PostHog (reuses isExtension/extensionContext from above)
|
||||
if (isExtension) {
|
||||
@@ -898,6 +1020,8 @@ export default function useChatController({
|
||||
modelDisplayNames[mi] = slot.model_name;
|
||||
}
|
||||
}
|
||||
userNodeDirty = true;
|
||||
pendingFlush = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -909,6 +1033,7 @@ export default function useChatController({
|
||||
!files.some((existingFile) => existingFile.id === newFile.id)
|
||||
);
|
||||
files = files.concat(newUserFiles);
|
||||
if (newUserFiles.length > 0) userNodeDirty = true;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "file_ids")) {
|
||||
@@ -928,15 +1053,20 @@ export default function useChatController({
|
||||
|
||||
// In multi-model mode, route per-model errors to the specific model's
|
||||
// node instead of killing the entire stream. Other models keep streaming.
|
||||
if (isMultiModel && streamingError.details?.model_index != null) {
|
||||
const errorModelIndex = streamingError.details
|
||||
.model_index as number;
|
||||
if (isMultiModel) {
|
||||
// Multi-model: isolate the error to its panel. Never throw
|
||||
// or set global error state — other models keep streaming.
|
||||
const errorModelIndex = streamingError.details?.model_index as
|
||||
| number
|
||||
| undefined;
|
||||
if (
|
||||
errorModelIndex != null &&
|
||||
errorModelIndex >= 0 &&
|
||||
errorModelIndex < initialAssistantNodes.length
|
||||
) {
|
||||
const errorNode = initialAssistantNodes[errorModelIndex]!;
|
||||
erroredModelIndices.add(errorModelIndex);
|
||||
dirtyModelIndices.delete(errorModelIndex);
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
@@ -963,8 +1093,15 @@ export default function useChatController({
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
} else {
|
||||
// Error without model_index in multi-model — can't route
|
||||
// to a specific panel. Log and continue; the stream loop
|
||||
// stays alive for other models.
|
||||
console.warn(
|
||||
"Multi-model error without model_index:",
|
||||
streamingError.error
|
||||
);
|
||||
}
|
||||
// Skip the normal per-packet upsert — we already upserted the error node
|
||||
continue;
|
||||
} else {
|
||||
// Single-model: kill the stream
|
||||
@@ -993,19 +1130,21 @@ export default function useChatController({
|
||||
|
||||
if (isMultiModel) {
|
||||
// Multi-model: route packet by placement.model_index.
|
||||
// OverallStop (type "stop") has model_index=null — it's a global
|
||||
// terminal packet that must be delivered to ALL models so each
|
||||
// panel's AgentMessage sees the stop and exits "Thinking..." state.
|
||||
// OverallStop (type "stop") has model_index=null — it's a
|
||||
// global terminal packet that must be delivered to ALL
|
||||
// models so each panel's AgentMessage sees the stop and
|
||||
// exits "Thinking..." state.
|
||||
const isGlobalStop =
|
||||
packetObj.type === "stop" &&
|
||||
typedPacket.placement?.model_index == null;
|
||||
|
||||
if (isGlobalStop) {
|
||||
for (let mi = 0; mi < packetsPerModel.length; mi++) {
|
||||
packetsPerModel[mi] = [
|
||||
...packetsPerModel[mi]!,
|
||||
typedPacket,
|
||||
];
|
||||
// Mutated in place — change detection uses packetCount, not array identity.
|
||||
packetsPerModel[mi]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(mi)) {
|
||||
dirtyModelIndices.add(mi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1015,10 +1154,10 @@ export default function useChatController({
|
||||
modelIndex >= 0 &&
|
||||
modelIndex < packetsPerModel.length
|
||||
) {
|
||||
packetsPerModel[modelIndex] = [
|
||||
...packetsPerModel[modelIndex]!,
|
||||
typedPacket,
|
||||
];
|
||||
packetsPerModel[modelIndex]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(modelIndex)) {
|
||||
dirtyModelIndices.add(modelIndex);
|
||||
}
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1048,6 +1187,7 @@ export default function useChatController({
|
||||
// Single-model
|
||||
packets.push(typedPacket);
|
||||
packetsVersion++;
|
||||
singleModelDirty = true;
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1074,73 +1214,16 @@ export default function useChatController({
|
||||
console.warn("Unknown packet:", JSON.stringify(packet));
|
||||
}
|
||||
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
// Build the messages to upsert based on single vs multi-model mode
|
||||
let messagesToUpsertInLoop: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
// Read the current user node from the tree to preserve childrenNodeIds
|
||||
// (initialUserNode has stale/empty children from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsertInLoop = [
|
||||
updatedUserNode,
|
||||
...buildNonErroredNodes(),
|
||||
];
|
||||
} else {
|
||||
messagesToUpsertInLoop = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsertInLoop,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
// Mark dirty — flushViaRAF coalesces bursts into one React update per frame.
|
||||
if (!isMultiModel) singleModelDirty = true;
|
||||
pendingFlush = true;
|
||||
}
|
||||
}
|
||||
// Flush any tail state from the final packet(s) before declaring
|
||||
// the stream complete. Without this, the last ≤1 frame of packets
|
||||
// could get stranded in local state.
|
||||
flushPendingUpdates();
|
||||
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
@@ -1174,6 +1257,7 @@ export default function useChatController({
|
||||
errorCode,
|
||||
isRetryable,
|
||||
errorDetails,
|
||||
is_generating: false,
|
||||
})
|
||||
: [
|
||||
{
|
||||
|
||||
@@ -106,9 +106,23 @@ export default function useMultiModelChat(
|
||||
[currentLlmModel]
|
||||
);
|
||||
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
// Removes the model at `index` from the selection. `next` is computed from
// the `selectedModels` value captured by this callback (hence its place in
// the dependency array) so the single-survivor check can run before the
// state update is queued.
const removeModel = useCallback(
|
||||
(index: number) => {
|
||||
const next = selectedModels.filter((_, i) => i !== index);
|
||||
// When dropping to single-model, switch llmManager to the surviving
|
||||
// model so it becomes the active model instead of reverting to the
|
||||
// user's default.
|
||||
if (next.length === 1 && next[0]) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: next[0].name,
|
||||
provider: next[0].provider,
|
||||
modelName: next[0].modelName,
|
||||
});
|
||||
}
|
||||
setSelectedModels(next);
|
||||
},
|
||||
[selectedModels, llmManager]
|
||||
);
|
||||
|
||||
const replaceModel = useCallback(
|
||||
(index: number, model: SelectedModel) => {
|
||||
|
||||
@@ -48,6 +48,7 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
});
|
||||
@@ -65,6 +66,7 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
|
||||
@@ -23,6 +23,7 @@ const DEFAULT_SETTINGS = {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
} satisfies Settings;
|
||||
|
||||
134
web/src/hooks/useTypewriter.ts
Normal file
134
web/src/hooks/useTypewriter.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
// Fixed reveal rate — NOT adaptive. Any ceil(delta/N) formula produces
|
||||
// visible chunks on burst packet arrivals. 1 = 60 cps, 2 = 120 cps,
// 3 = 180 cps (figures assume a 60 Hz animation-frame rate).
// NOTE(review): the rates listed originally stopped at 2 while the value
// below is 3 — confirm 3 is the intended setting.
|
||||
const CHARS_PER_FRAME = 3;
|
||||
|
||||
/**
|
||||
* Reveals `target` progressively, advancing by at most CHARS_PER_FRAME
* characters on each animation frame.
|
||||
* When `enabled` is false (historical messages), snaps to full on mount.
|
||||
* The rAF loop pauses once caught up and resumes when `target` grows.
*
* @param target  Full text collected so far; may grow (or shrink) between renders.
* @param enabled Whether growth should be animated. When false, the displayed
*                length is snapped to `target.length` instead of animated.
* @returns The prefix of `target` that is currently revealed.
|
||||
*/
|
||||
export function useTypewriter(target: string, enabled: boolean): string {
|
||||
// Ref so the rAF loop reads latest length without restarting.
|
||||
const targetRef = useRef(target);
|
||||
targetRef.current = target;
|
||||
|
||||
// Mirror `enabled` so the restart effect can short-circuit when the
|
||||
// caller has turned animation off (e.g. voice-mode, where display is
|
||||
// driven by audio position — the typewriter must stay idle and not
|
||||
// animate a jump after audio ends).
|
||||
const enabledRef = useRef(enabled);
|
||||
enabledRef.current = enabled;
|
||||
|
||||
// `enabled` only seeds the initial state: start from 0 when animating,
|
||||
// or from the full length when not. Later flips are observed through
|
||||
// enabledRef inside `start()`: a loop already in flight is not stopped by
|
||||
// the flip itself (tick does not read enabledRef); the next `start()` call
// with animation off snaps to the full target instead of scheduling frames.
|
||||
const [displayedLength, setDisplayedLength] = useState<number>(
|
||||
enabled ? 0 : target.length
|
||||
);
|
||||
|
||||
// Mirror displayedLength in a ref so the rAF loop can read the latest
|
||||
// value without stale-closure issues AND without needing a functional
|
||||
// state updater (which must be pure — no ref mutations inside).
// Every setDisplayedLength call below writes this ref first to keep the
// two in step.
|
||||
const displayedLengthRef = useRef(displayedLength);
|
||||
|
||||
// Clamp (not reset) on target shrink — preserves already-revealed chars
|
||||
// across user-cancel freeze and regeneration.
|
||||
const prevTargetLengthRef = useRef(target.length);
|
||||
useEffect(() => {
|
||||
if (target.length < prevTargetLengthRef.current) {
|
||||
const clamped = Math.min(displayedLengthRef.current, target.length);
|
||||
displayedLengthRef.current = clamped;
|
||||
setDisplayedLength(clamped);
|
||||
}
|
||||
prevTargetLengthRef.current = target.length;
|
||||
}, [target.length]);
|
||||
|
||||
// Self-scheduling rAF loop. Pauses when caught up so idle/historical
|
||||
// messages don't run a 60fps no-op updater for their entire lifetime.
// rafIdRef: id of the pending frame (null when none is scheduled).
// runningRef: true while the loop is scheduled; guards against double starts.
// startLoopRef: lets the restart effect below invoke `start` defined in the
// mount effect; null before mount and after unmount.
|
||||
const rafIdRef = useRef<number | null>(null);
|
||||
const runningRef = useRef(false);
|
||||
const startLoopRef = useRef<(() => void) | null>(null);
|
||||
|
||||
// Mount-only effect (empty deps): everything it needs is read through refs.
useEffect(() => {
|
||||
const tick = () => {
|
||||
const targetLen = targetRef.current.length;
|
||||
const prev = displayedLengthRef.current;
|
||||
if (prev >= targetLen) {
|
||||
// Caught up — pause the loop. The sibling effect below will
|
||||
// restart it when `target` grows.
|
||||
runningRef.current = false;
|
||||
rafIdRef.current = null;
|
||||
return;
|
||||
}
|
||||
// Advance by the fixed step, never past the end of the target.
const next = Math.min(prev + CHARS_PER_FRAME, targetLen);
|
||||
displayedLengthRef.current = next;
|
||||
setDisplayedLength(next);
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
const start = () => {
|
||||
// Already scheduled — nothing to do.
if (runningRef.current) return;
|
||||
// Animation disabled — snap to full and stay idle. This is the
|
||||
// voice-mode path where content is driven by audio position, and
|
||||
// any "gap" (e.g. user stops audio early) must jump instantly
|
||||
// instead of animating a 1500-char typewriter burst.
|
||||
if (!enabledRef.current) {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current !== targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
return;
|
||||
}
|
||||
runningRef.current = true;
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
// Publish `start` so the restart effect can reach it.
startLoopRef.current = start;
|
||||
|
||||
// Text may already be pending at mount time.
if (targetRef.current.length > displayedLengthRef.current) {
|
||||
start();
|
||||
}
|
||||
|
||||
// Unmount: mark the loop stopped, cancel any pending frame, and clear the
// published `start` so the restart effect becomes a no-op.
return () => {
|
||||
runningRef.current = false;
|
||||
if (rafIdRef.current !== null) {
|
||||
cancelAnimationFrame(rafIdRef.current);
|
||||
rafIdRef.current = null;
|
||||
}
|
||||
startLoopRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Restart the loop when target grows past what's currently displayed.
// `start` itself ignores the call while the loop is already running.
|
||||
useEffect(() => {
|
||||
if (target.length > displayedLength && startLoopRef.current) {
|
||||
startLoopRef.current();
|
||||
}
|
||||
}, [target.length, displayedLength]);
|
||||
|
||||
// When the user navigates away and back (tab switch, window focus),
|
||||
// snap to all collected content so they see the full response immediately.
// Only ever moves forward: nothing happens if the display is already caught up.
|
||||
useEffect(() => {
|
||||
const handleVisibility = () => {
|
||||
if (document.visibilityState === "visible") {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current < targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
}
|
||||
};
|
||||
document.addEventListener("visibilitychange", handleVisibility);
|
||||
return () =>
|
||||
document.removeEventListener("visibilitychange", handleVisibility);
|
||||
}, []);
|
||||
|
||||
// Revealed prefix. Math.min keeps the slice end within `target` even if
// displayedLength is momentarily larger than a target that just shrank
// (the clamp effect above only runs after the render).
return useMemo(
|
||||
() => target.slice(0, Math.min(displayedLength, target.length)),
|
||||
[target, displayedLength]
|
||||
);
|
||||
}
|
||||
@@ -27,6 +27,7 @@ export interface Settings {
|
||||
query_history_type: QueryHistoryType;
|
||||
|
||||
deep_research_enabled?: boolean;
|
||||
multi_model_chat_enabled?: boolean;
|
||||
search_ui_enabled?: boolean;
|
||||
|
||||
// Image processing settings
|
||||
|
||||
@@ -173,8 +173,13 @@ function AttachmentItemLayout({
|
||||
rightChildren,
|
||||
}: AttachmentItemLayoutProps) {
|
||||
return (
|
||||
<Section flexDirection="row" gap={0.25} padding={0.25}>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08")}>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
gap={0.25}
|
||||
padding={0.25}
|
||||
>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08 flex-shrink-0")}>
|
||||
<Section>
|
||||
<div
|
||||
className="attachment-button__icon-wrapper"
|
||||
@@ -189,6 +194,7 @@ function AttachmentItemLayout({
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
gap={1.5}
|
||||
className="min-w-0"
|
||||
>
|
||||
<div data-testid="attachment-item-title" className="flex-1 min-w-0">
|
||||
<Content
|
||||
|
||||
@@ -9,6 +9,7 @@ import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import Label from "@/refresh-components/form/Label";
|
||||
import type { TagProps } from "@opal/components/tag/components";
|
||||
|
||||
interface OrientationLayoutProps {
|
||||
name?: string;
|
||||
@@ -16,6 +17,8 @@ interface OrientationLayoutProps {
|
||||
nonInteractive?: boolean;
|
||||
children?: React.ReactNode;
|
||||
title: string | RichStr;
|
||||
/** Tag rendered inline beside the title (passed through to Content). */
|
||||
tag?: TagProps;
|
||||
description?: string | RichStr;
|
||||
suffix?: "optional" | (string & {});
|
||||
sizePreset?: "main-content" | "main-ui";
|
||||
@@ -128,6 +131,7 @@ function HorizontalInputLayout({
|
||||
children,
|
||||
center,
|
||||
title,
|
||||
tag,
|
||||
description,
|
||||
suffix,
|
||||
sizePreset = "main-content",
|
||||
@@ -144,6 +148,7 @@ function HorizontalInputLayout({
|
||||
title={title}
|
||||
description={description}
|
||||
suffix={suffix}
|
||||
tag={tag}
|
||||
sizePreset={sizePreset}
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user