Compare commits

..

12 Commits

Author SHA1 Message Date
Nik
c8e565fa75 fix(chat): fix B1/B2/P1 bugs in multi-model streaming + cleanup
B1 — Self-completion race: model finishes before GeneratorExit fires,
exits _run_model with drain_done=False, skips self-completion. Fix:
add completion_locks (one per model); disconnect else-branch claims
lock and calls llm_loop_completion_handle for already-succeeded models.

B2 — Stop-button saves wrong message for errored models: the stop loop
called llm_loop_completion_handle for all models including ones that
threw exceptions, persisting "stopped by user" for an errored model.
Fix: add model_errored flag; stop loop skips errored models.

P1 — Orphaned ChatMessage rows for errored models: reserved_messages
were never cleaned up when a model errored. Fix: delete via
db_session.get(ChatMessage, id) in all three exit paths (normal
completion, stop-button, disconnect).

Also: extract repeated orphan-cleanup into _delete_orphaned_message
nested helper; remove dead check_call_count variable in tests; rename
ctx→worker_context, _completion_done→completion_persisted; replace
functools.partial with captured-variable lambda; fix stale docstring
("bounded"→"unbounded"); add _CANCEL_POLL_INTERVAL_S named constant;
if/if/if→if/elif/elif; %-style logger calls throughout.

Tests: two new regression tests (B1 race, B2 stop-button errored model).
26 tests pass. mypy clean.
2026-04-01 08:17:50 -07:00
Nik
bab95d8bf0 fix(chat): remove duplicate drain_done declaration after rebase 2026-03-31 20:02:29 -07:00
Nik
eb7bc74e1b fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old code called executor.shutdown(wait=False) with
no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly

Unit test updated to reflect the async self-completion semantics: the test
blocks the worker inside run_llm_loop until gen.close() sets drain_done,
then waits for completion_called inside the patch context (while mocks are
still active) to avoid calling the real get_session_with_current_tenant.
2026-03-31 20:02:29 -07:00
Nik
29da0aefb5 feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
Adds support for running 2-3 LLMs in parallel within a single chat turn,
with responses streamed interleaved to the frontend via the merged queue
infrastructure introduced in the preceding PR.

Backend changes
- process_message.py: restore llm_overrides param on build_chat_turn and
  _stream_chat_turn; restore is_multi branching for LLM setup, context
  window sizing, and message ID reservation; add _build_model_display_name
  and handle_multi_model_stream (public multi-model entrypoint)
- db/chat.py: add reserve_multi_model_message_ids (reserves N assistant
  message placeholders sharing the same parent), set_preferred_response
  (marks one response as the user's preferred), and extend
  translate_db_message_to_chat_message_detail with preferred_response_id
  and model_display_name fields
- chat_backend.py: route requests with llm_overrides >1 through
  handle_multi_model_stream; reject non-streaming multi-model requests with
  OnyxError; add /set-preferred-response endpoint

Tests
- test_multi_model_streaming.py: unit tests for _run_models drain loop
  (arrival-order yield, error isolation, cancellation), handle_multi_model_stream
  validation guards, and N=1 backwards-compatibility
2026-03-31 20:01:21 -07:00
Nik
6c86301c51 fix(chat): remove bounded queue and packet drops — match old behavior
Old code used queue.Queue() (unbounded, blocking put). New code introduced
queue.Queue(maxsize=100) + put(timeout=3.0) + silent drop on queue.Full —
a regression in all three callsites:

- Emitter.emit(): data packets silently dropped on queue full
- _run_model exception path: model errors silently lost
- _run_model finally (_MODEL_DONE): if dropped, drain loop hangs forever
  (models_remaining never reaches 0)

Fix: remove maxsize, remove all timeout= arguments, remove all
except queue.Full handlers. The drain_done early-return in emit() is the
correct disconnect mechanism; queue backpressure is not needed.

Also adds _completion_done: bool type annotation and fixes the queue drain
comment (no longer unblocking timed-out puts — just releasing memory).
2026-03-31 20:00:46 -07:00
Nik
631146f48f fix(chat): use model_succeeded instead of check_is_connected on self-completion
On HTTP disconnect, check_is_connected() returns False, causing
llm_loop_completion_handle to treat a completed response as
user-cancelled and append "Generation was stopped by the user."
Use lambda: model_succeeded[model_idx] (always True here) instead,
matching the cancellation path's functools.partial(bool, model_succeeded[i]).
2026-03-31 18:42:04 -07:00
Nik
f327278506 fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old else branch just called executor.shutdown(wait=False)
with no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly
2026-03-31 18:14:50 -07:00
Nik
c7cc439862 fix(emitter): address Greptile P1/P2/P3 and Queue typing
- P1: executor.shutdown(wait=False) on early exit — don't block the
  server thread waiting for LLM workers; they will hit queue.Full
  timeouts and exit on their own (matches old run_chat_loop behavior)
- P2: wrap db_session.commit() in try/finally in build_chat_turn —
  reset processing status before propagating if commit fails, so the
  chat session isn't stuck at "processing" permanently
- P3: fix inaccurate comment "All worker threads have exited" — workers
  may still be closing their own DB sessions at that point; clarify
  that only the main-thread db_session is safe to use
- Queue[Any] → Queue[tuple[int, Packet | Exception | object]] in Emitter
2026-03-31 17:02:46 -07:00
Nik
3365a369e2 fix(review): address Greptile comments
- Add owner to bare TODO comment
- Restore placement field assertions weakened by Emitter refactor
2026-03-31 12:49:09 -07:00
Nik
470bda3fb5 refactor(chat): elegance pass on PR1 changed files
process_message.py:
- Fix `skip_clarification` field in ChatTurnSetup: inline comment inside
  the type annotation → separate `#` comment on the line above the field
- Flatten `model_tools` via list comprehension instead of manual extend loop
- `forced_tool_id` membership test: list → set comprehension (O(1) lookup)
- Trim `_run_model` inner-function docstring — private closure doesn't need
  10-line Args block
- Remove redundant inline param comments from `_stream_chat_turn` and
  `handle_stream_message_objects` where the docstring Args section already
  documents them
- Strip duplicate Args/Returns from `handle_stream_message_objects` docstring
  — it delegates entirely to `_stream_chat_turn`

emitter.py:
- Widen `merged_queue` annotation to `Queue[Any]`: Queue is invariant so
  `Queue[tuple[int, Packet]]` can't be passed a `Queue[tuple[int, Packet |
  Exception | object]]`; the emitter is a write-only producer and doesn't
  care what else lives on the queue
2026-03-31 12:16:38 -07:00
Nik
13f511e209 refactor(emitter): clean up string annotation and use model_copy
- Fix `"Queue"` forward-reference annotation → `Queue[tuple[int, Packet]]`
  (Queue is already imported, the string was unnecessary)
- Replace manual Placement field copy with `base.model_copy(update={...})`
- Remove redundant `key` variable (was just `self._model_idx`)
- Tighten docstring
2026-03-31 11:44:28 -07:00
Nik
c5e8ba1eab refactor(chat): replace bus-polling emitter with merged-queue streaming; fix 429 hang
Switch Emitter from a per-model event bus + polling thread to a single
bounded queue shared across all models.  Each emit() call puts directly onto
the queue; the drain loop in _run_models yields packets in arrival order.

Key changes
- emitter.py: remove Bus, get_default_emitter(); add Emitter(merged_queue, model_idx)
- chat_state.py: remove run_chat_loop_with_state_containers (113-line bus-poll loop)
- process_message.py: add ChatTurnSetup dataclass and build_chat_turn(); rewrite
  _stream_chat_turn + _run_models around the merged queue; single-model (N=1)
  path is fully backwards-compatible
- placement.py, override_models.py: add docstrings; LLMOverride gains display_name
- research_agent.py, custom_tool.py: update Emitter call sites
- test_emitter.py: new unit tests for queue routing, model_index tagging, placement

Frontend 429 fix
- lib.tsx: parse response body for human-readable detail on non-2xx responses
  instead of "HTTP error! status: 429"
- useChatController.ts: surface stack.error after the FIFO drain loop exits so
  the catch block replaces the thinking placeholder with an error message
2026-03-30 22:18:48 -07:00
206 changed files with 3806 additions and 9911 deletions

View File

@@ -704,9 +704,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -789,9 +786,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"

View File

@@ -1,54 +0,0 @@
"""csv to tabular chat file type
Revision ID: 8188861f4e92
Revises: d8cdfee5df80
Create Date: 2026-03-31 19:23:05.753184
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8188861f4e92"
down_revision = "d8cdfee5df80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'csv'
THEN jsonb_set(elem, '{type}', '"tabular"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "csv"%'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'tabular'
THEN jsonb_set(elem, '{type}', '"csv"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "tabular"%'
"""
)

View File

@@ -1,55 +0,0 @@
"""add skipped to userfilestatus
Revision ID: d8cdfee5df80
Revises: 1d78c0ca7853
Create Date: 2026-04-01 10:47:12.593950
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d8cdfee5df80"
down_revision = "1d78c0ca7853"
branch_labels = None
depends_on = None
TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = (
"PROCESSING",
"INDEXING",
"COMPLETED",
"SKIPPED",
"FAILED",
"CANCELED",
"DELETING",
)
def _drop_status_check_constraint() -> None:
inspector = sa.inspect(op.get_bind())
for constraint in inspector.get_check_constraints(TABLE):
if COLUMN in constraint.get("sqltext", ""):
constraint_name = constraint["name"]
if constraint_name is not None:
op.drop_constraint(constraint_name, TABLE, type_="check")
def upgrade() -> None:
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
def downgrade() -> None:
op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

View File

@@ -5,7 +5,6 @@ from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.hooks",
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",

View File

@@ -55,15 +55,6 @@ ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,

View File

@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
@shared_task(
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
"""
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
idempotent, tenants already at head are a fast no-op. This ensures pool
tenants are always current so that signup doesn't hit schema mismatches
(e.g. missing columns added after the tenant was pre-provisioned).
"""
with get_session_with_shared_schema() as db_session:
pool_tenants = db_session.query(AvailableTenant).all()
tenant_ids = [t.tenant_id for t in pool_tenants]
if not tenant_ids:
return
task_logger.info(
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
)
for tenant_id in tenant_ids:
try:
run_alembic_migrations(tenant_id)
new_version = get_current_alembic_version(tenant_id)
with get_session_with_shared_schema() as db_session:
tenant = (
db_session.query(AvailableTenant)
.filter_by(tenant_id=tenant_id)
.first()
)
if tenant and tenant.alembic_version != new_version:
task_logger.info(
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
)
tenant.alembic_version = new_version
db_session.commit()
except Exception:
task_logger.exception(
f"Failed to migrate pool tenant {tenant_id}, skipping"
)
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -69,7 +69,5 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
"/admin/token-rate-limits",
# Evals
"/evals",
# Hook extensions
"/admin/hooks",
}
)

View File

@@ -1,385 +0,0 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if MULTI_TENANT:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def _execute_hook_impl(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -15,7 +15,6 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.features.hooks.api import router as hook_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
@@ -139,7 +138,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
include_router_with_global_prefix_prepended(application, evals_router)
include_router_with_global_prefix_prepended(application, hook_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -66,7 +65,6 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -517,8 +515,7 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
sender: Any, **kwargs: Any # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -9,7 +9,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -138,7 +137,6 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
)
current_db_time = get_db_current_time(db_session)
total_batches: int | str = (
coordination_status.total_batches
if coordination_status.total_batches is not None
else "?"
)
if coordination_status.found:
task_logger.info(
f"Indexing attempt progress: "
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"completed_batches={coordination_status.completed_batches} "
f"total_batches={total_batches} "
f"total_batches={coordination_status.total_batches or '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_failures={coordination_status.total_failures}"
f"elapsed={(current_db_time - attempt.time_created).seconds}"
@@ -415,7 +410,7 @@ def check_indexing_completion(
logger.info(
f"Indexing status: "
f"indexing_completed={indexing_completed} "
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
f"batches_processed={batches_processed}/{batches_total or '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_chunks={coordination_status.total_chunks} "
f"total_failures={coordination_status.total_failures}"

View File

@@ -5,7 +5,6 @@ from typing import cast
from uuid import UUID
from fastapi.datastructures import Headers
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.chat.models import ChatHistoryResult
@@ -52,60 +51,6 @@ logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
class FileContextResult(BaseModel):
"""Result of building a file's LLM context representation."""
message: ChatMessageSimple
tool_metadata: FileToolMetadata
def build_file_context(
tool_file_id: str,
filename: str,
file_type: ChatFileType,
content_text: str | None = None,
token_count: int = 0,
approx_char_count: int | None = None,
) -> FileContextResult:
"""Build the LLM context representation for a single file.
Centralises how files should appear in the LLM prompt
— the ID that FileReaderTool accepts (``UserFile.id`` for user files).
"""
if file_type.use_metadata_only():
message_text = (
f"File: {filename} (id={tool_file_id})\n"
"Use the file_reader or python tools to access "
"this file's contents."
)
message = ChatMessageSimple(
message=message_text,
token_count=max(1, len(message_text) // 4),
message_type=MessageType.USER,
file_id=tool_file_id,
)
else:
message_text = f"File: {filename}\n{content_text or ''}\nEnd of File"
message = ChatMessageSimple(
message=message_text,
token_count=token_count,
message_type=MessageType.USER,
file_id=tool_file_id,
)
metadata = FileToolMetadata(
file_id=tool_file_id,
filename=filename,
approx_char_count=(
approx_char_count
if approx_char_count is not None
else len(content_text or "")
),
)
return FileContextResult(message=message, tool_metadata=metadata)
def create_chat_session_from_request(
chat_session_request: ChatSessionCreationRequest,
user_id: UUID | None,
@@ -593,7 +538,7 @@ def convert_chat_history(
for idx, chat_message in enumerate(chat_history):
if chat_message.message_type == MessageType.USER:
# Process files attached to this message
text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = []
text_files: list[ChatLoadedFile] = []
image_files: list[ChatLoadedFile] = []
if chat_message.files:
@@ -604,26 +549,34 @@ def convert_chat_history(
if loaded_file.file_type == ChatFileType.IMAGE:
image_files.append(loaded_file)
else:
# Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages
text_files.append((loaded_file, file_descriptor))
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
text_files.append(loaded_file)
# Add text files as separate messages before the user message.
# Each message is tagged with ``file_id`` so that forgotten files
# can be detected after context-window truncation.
for text_file, fd in text_files:
# Use user_file_id as the FileReaderTool accepts that.
# Fall back to the file-store path id.
tool_id = fd.get("user_file_id") or text_file.file_id
filename = text_file.filename or "unknown"
ctx = build_file_context(
tool_file_id=tool_id,
filename=filename,
file_type=text_file.file_type,
content_text=text_file.content_text,
token_count=text_file.token_count,
for text_file in text_files:
file_text = text_file.content_text or ""
filename = text_file.filename
message = (
f"File: {filename}\n{file_text}\nEnd of File"
if filename
else file_text
)
simple_messages.append(
ChatMessageSimple(
message=message,
token_count=text_file.token_count,
message_type=MessageType.USER,
image_files=None,
file_id=text_file.file_id,
)
)
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
file_id=text_file.file_id,
filename=filename or "unknown",
approx_char_count=len(file_text),
)
simple_messages.append(ctx.message)
all_injected_file_metadata[tool_id] = ctx.tool_metadata
# Sum token counts from image files (excluding project image files)
image_token_count = (

View File

@@ -24,7 +24,6 @@ from onyx.cache.factory import get_cache_backend
from onyx.cache.interface import CacheBackend
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import build_file_context
from onyx.chat.chat_utils import convert_chat_history
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_session_from_request
@@ -102,7 +101,6 @@ from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import ModelResponseSlot
@@ -315,27 +313,16 @@ def extract_context_files(
if not user_files:
return _empty_extracted_context_files()
# Aggregate tokens for the file content that will be added
# Skip tokens for those with metadata only
aggregate_tokens = sum(
uf.token_count or 0
for uf in user_files
if not mime_type_to_chat_file_type(uf.file_type).use_metadata_only()
)
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
max_actual_tokens = (
llm_max_context_window - reserved_token_count
) * max_llm_context_percentage
if aggregate_tokens >= max_actual_tokens:
tool_metadata = []
use_as_search_filter = not DISABLE_VECTOR_DB
if DISABLE_VECTOR_DB:
overflow_tool_metadata = [_build_tool_metadata(uf) for uf in user_files]
else:
overflow_tool_metadata = [
_build_tool_metadata(uf)
for uf in user_files
if mime_type_to_chat_file_type(uf.file_type).use_metadata_only()
]
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
return ExtractedContextFiles(
file_texts=[],
image_files=[],
@@ -343,11 +330,11 @@ def extract_context_files(
total_token_count=0,
file_metadata=[],
uncapped_token_count=aggregate_tokens,
file_metadata_for_tool=overflow_tool_metadata,
file_metadata_for_tool=tool_metadata,
)
# Files fit — load them into context
user_file_map = {uf.file_id: uf for uf in user_files}
user_file_map = {str(uf.id): uf for uf in user_files}
in_memory_files = load_in_memory_chat_files(
user_file_ids=[uf.id for uf in user_files],
db_session=db_session,
@@ -356,38 +343,23 @@ def extract_context_files(
file_texts: list[str] = []
image_files: list[ChatLoadedFile] = []
file_metadata: list[ContextFileMetadata] = []
tool_metadata: list[FileToolMetadata] = []
total_token_count = 0
for f in in_memory_files:
uf = user_file_map.get(str(f.file_id))
filename = f.filename or f"file_{f.file_id}"
if f.file_type.use_metadata_only():
# Metadata-only files are not injected as full text.
# Only the metadata is provided, with LLM using tools
if not uf:
logger.error(
f"File with id={f.file_id} in metadata-only path with no associated user file"
)
continue
tool_metadata.append(_build_tool_metadata(uf))
elif f.file_type.is_text_file():
if f.file_type.is_text_file():
text_content = _extract_text_from_in_memory_file(f)
if not text_content:
continue
if not uf:
logger.warning(f"No user file for file_id={f.file_id}")
continue
file_texts.append(text_content)
file_metadata.append(
ContextFileMetadata(
file_id=str(uf.id),
filename=filename,
file_id=str(f.file_id),
filename=f.filename or f"file_{f.file_id}",
file_content=text_content,
)
)
if uf.token_count:
if uf and uf.token_count:
total_token_count += uf.token_count
elif f.file_type == ChatFileType.IMAGE:
token_count = uf.token_count if uf and uf.token_count else 0
@@ -410,25 +382,24 @@ def extract_context_files(
total_token_count=total_token_count,
file_metadata=file_metadata,
uncapped_token_count=aggregate_tokens,
file_metadata_for_tool=tool_metadata,
)
APPROX_CHARS_PER_TOKEN = 4
def _build_tool_metadata(user_file: UserFile) -> FileToolMetadata:
"""Build lightweight FileToolMetadata from a UserFile record.
Delegates to ``build_file_context`` so that the file ID exposed to the
LLM is always consistent with what FileReaderTool expects.
"""
return build_file_context(
tool_file_id=str(user_file.id),
filename=user_file.name,
file_type=mime_type_to_chat_file_type(user_file.file_type),
approx_char_count=(user_file.token_count or 0) * APPROX_CHARS_PER_TOKEN,
).tool_metadata
def _build_file_tool_metadata_for_user_files(
user_files: list[UserFile],
) -> list[FileToolMetadata]:
"""Build lightweight FileToolMetadata from a list of UserFile records."""
return [
FileToolMetadata(
file_id=str(uf.id),
filename=uf.name,
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
)
for uf in user_files
]
def determine_search_params(

View File

@@ -1079,6 +1079,7 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

View File

@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -673,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -1,32 +0,0 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
canvas_client: CanvasApiClient,
course_id: int,
) -> ExternalAccess | None:
if not global_version.is_ee_version():
return None
ee_get_course_permissions = cast(
Callable[[CanvasApiClient, int], ExternalAccess | None],
fetch_versioned_implementation(
"onyx.external_permissions.canvas.access",
"get_course_permissions",
),
)
return ee_get_course_permissions(canvas_client, course_id)

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
@@ -191,22 +190,3 @@ class CanvasApiClient:
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,82 +1,17 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
"""Map Canvas API errors to connector framework exceptions."""
if e.status_code == 401:
raise CredentialExpiredError(
"Canvas API token is invalid or expired (HTTP 401)."
)
elif e.status_code == 403:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
)
else:
raise ConnectorValidationError(
f"Canvas API error (status={e.status_code}): {e}"
)
class CanvasCourse(BaseModel):
id: int
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
name: str
course_code: str
created_at: str
workflow_state: str
class CanvasPage(BaseModel):
@@ -84,22 +19,10 @@ class CanvasPage(BaseModel):
url: str
title: str
body: str | None = None
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
@@ -107,23 +30,10 @@ class CanvasAssignment(BaseModel):
description: str | None = None
html_url: str
course_id: int
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
@@ -133,17 +43,6 @@ class CanvasAnnouncement(BaseModel):
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
@@ -173,286 +72,3 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
try:
self.canvas_client.get("courses", params={"per_page": "1"})
logger.info("Canvas connector settings validated successfully")
except OnyxError as e:
_handle_canvas_api_error(e)
except ConnectorMissingCredentialError:
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
# TODO(benwu408): implemented in PR4 (perm sync)
raise NotImplementedError

View File

@@ -11,13 +11,11 @@ from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.errors import LoginFailure
from discord.flags import Intents
from discord.message import Message as DiscordMessage
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -211,19 +209,8 @@ def _manage_async_retrieval(
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents) as discord_client:
start_task = asyncio.create_task(discord_client.start(token))
ready_task = asyncio.create_task(discord_client.wait_until_ready())
done, _ = await asyncio.wait(
{start_task, ready_task},
return_when=asyncio.FIRST_COMPLETED,
)
# start() runs indefinitely once connected, so it only lands
# in `done` when login/connection failed — propagate the error.
if start_task in done:
ready_task.cancel()
start_task.result()
asyncio.create_task(discord_client.start(token))
await discord_client.wait_until_ready()
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=discord_client,
@@ -289,19 +276,6 @@ class DiscordConnector(PollConnector, LoadConnector):
self._discord_bot_token = credentials["discord_bot_token"]
return None
def validate_connector_settings(self) -> None:
loop = asyncio.new_event_loop()
try:
client = Client(intents=Intents.default())
try:
loop.run_until_complete(client.login(self.discord_bot_token))
except LoginFailure as e:
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
finally:
loop.run_until_complete(client.close())
finally:
loop.close()
def _manage_doc_batching(
self,
start: datetime | None = None,

View File

@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -215,7 +215,6 @@ class UserFileStatus(str, PyEnum):
PROCESSING = "PROCESSING"
INDEXING = "INDEXING"
COMPLETED = "COMPLETED"
SKIPPED = "SKIPPED"
FAILED = "FAILED"
CANCELED = "CANCELED"
DELETING = "DELETING"

View File

@@ -7,7 +7,6 @@ from fastapi import HTTPException
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
@@ -18,7 +17,6 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import UserFileStatus
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
@@ -36,19 +34,9 @@ class CategorizedFilesResult(BaseModel):
user_files: list[UserFile]
rejected_files: list[RejectedFile]
id_to_temp_id: dict[str, str]
# Filenames that should be stored but not indexed.
skip_indexing_filenames: set[str] = Field(default_factory=set)
# Allow SQLAlchemy ORM models inside this result container
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def indexable_files(self) -> list[UserFile]:
return [
uf
for uf in self.user_files
if (uf.name or "") not in self.skip_indexing_filenames
]
def build_hashed_file_key(file: UploadFile) -> str:
name_prefix = (file.filename or "")[:50]
@@ -82,7 +70,6 @@ def create_user_files(
)
if new_temp_id is not None:
id_to_temp_id[str(new_id)] = new_temp_id
should_skip = (file.filename or "") in categorized_files.skip_indexing
new_file = UserFile(
id=new_id,
user_id=user.id,
@@ -94,7 +81,6 @@ def create_user_files(
link_url=link_url,
content_type=file.content_type,
file_type=file.content_type,
status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
)
# Persist the UserFile first to satisfy FK constraints for association table
@@ -112,7 +98,6 @@ def create_user_files(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files.skip_indexing,
)
@@ -138,7 +123,6 @@ def upload_files_to_user_files_with_indexing(
user_files = categorized_files_result.user_files
rejected_files = categorized_files_result.rejected_files
id_to_temp_id = categorized_files_result.id_to_temp_id
indexable_files = categorized_files_result.indexable_files
# Trigger per-file processing immediately for the current tenant
tenant_id = get_current_tenant_id()
for rejected_file in rejected_files:
@@ -150,12 +134,12 @@ def upload_files_to_user_files_with_indexing(
from onyx.background.task_utils import drain_processing_loop
background_tasks.add_task(drain_processing_loop, tenant_id)
for user_file in indexable_files:
for user_file in user_files:
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
for user_file in indexable_files:
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
@@ -171,7 +155,6 @@ def upload_files_to_user_files_with_indexing(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
)

View File

@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
) -> list[str]:
"""Searches the index and returns only document chunk IDs.

View File

@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
RANDOM = "random"
DOC_ID_RETRIEVAL = "doc_id_retrieval"
ID_RETRIEVAL = "id_retrieval"
DOCUMENT_IDS = "document_ids"
UNKNOWN = "unknown"

View File

@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -15,7 +15,6 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
class OnyxMimeTypes:
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
CSV_MIME_TYPES = {"text/csv"}
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
TEXT_MIME_TYPES = {
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
@@ -35,12 +34,13 @@ class OnyxMimeTypes:
PDF_MIME_TYPE,
WORD_PROCESSING_MIME_TYPE,
PRESENTATION_MIME_TYPE,
SPREADSHEET_MIME_TYPE,
"message/rfc822",
"application/epub+zip",
}
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
)
EXCLUDED_IMAGE_TYPES = {
@@ -53,11 +53,6 @@ class OnyxMimeTypes:
class OnyxFileExtensions:
TABULAR_EXTENSIONS = {
".csv",
".tsv",
".xlsx",
}
PLAIN_TEXT_EXTENSIONS = {
".txt",
".md",

View File

@@ -13,21 +13,15 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
# Tabular data files (CSV, XLSX)
TABULAR = "tabular"
CSV = "csv"
def is_text_file(self) -> bool:
return self in (
ChatFileType.PLAIN_TEXT,
ChatFileType.DOC,
ChatFileType.TABULAR,
ChatFileType.CSV,
)
def use_metadata_only(self) -> bool:
"""File types where we can ignore the file content
and only use the metadata."""
return self in (ChatFileType.TABULAR,)
class FileDescriptor(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column

View File

@@ -110,20 +110,16 @@ def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
# check for plain text normalized version first, then use original file otherwise
try:
file_io = file_store.read_file(plaintext_file_name, mode="b")
# Metadata-only file types preserve their original type so
# downstream injection paths can route them correctly.
if chat_file_type.use_metadata_only():
plaintext_chat_file_type = chat_file_type
elif file_io is not None:
# if we have plaintext for image (which happens when image
# extraction is enabled), we use PLAIN_TEXT type
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
plaintext_chat_file_type = (
ChatFileType.PLAIN_TEXT
if chat_file_type != ChatFileType.IMAGE
else chat_file_type
)
# if we have plaintext for image (which happens when image extraction is enabled), we use PLAIN_TEXT type
if file_io is not None:
plaintext_chat_file_type = ChatFileType.PLAIN_TEXT
else:
plaintext_chat_file_type = (
ChatFileType.PLAIN_TEXT
if chat_file_type != ChatFileType.IMAGE
else chat_file_type
)
chat_file = InMemoryChatFile(
file_id=str(user_file.file_id),

View File

@@ -1,3 +1,4 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
@@ -6,7 +7,10 @@ from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted EE deployments.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Use as: Depends(require_hook_enabled)
"""
@@ -15,3 +19,8 @@ def require_hook_enabled() -> None:
OnyxErrorCode.SINGLE_TENANT_ONLY,
"Hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.ENV_VAR_GATED,
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -1,22 +1,79 @@
"""CE hook executor.
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
HookSkipped and HookSoftFailed are real classes kept here because
process_message.py (CE code) uses isinstance checks against them.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
execute_hook is the public entry point. It dispatches to _execute_hook_impl
via fetch_versioned_implementation so that:
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
@@ -30,15 +87,277 @@ class HookSoftFailed:
T = TypeVar("T", bound=BaseModel)
def _execute_hook_impl(
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
db_session: Session, # noqa: ARG001
hook_point: HookPoint, # noqa: ARG001
payload: dict[str, Any], # noqa: ARG001
response_type: type[T], # noqa: ARG001
) -> T | HookSkipped | HookSoftFailed:
"""CE no-op — hooks are not available without EE."""
return HookSkipped()
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def execute_hook(
@@ -48,15 +367,25 @@ def execute_hook(
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point.
"""Execute the hook for the given hook point synchronously.
Dispatches to the versioned implementation so EE gets the real executor
and CE gets the no-op stub, without any changes at the call site.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
return impl(
db_session=db_session,
hook_point=hook_point,
payload=payload,
response_type=response_type,
)
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -438,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -454,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -76,18 +76,11 @@ class CategorizedFiles(BaseModel):
acceptable: list[UploadFile] = Field(default_factory=list)
rejected: list[RejectedFile] = Field(default_factory=list)
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
# Filenames within `acceptable` that should be stored but not indexed.
skip_indexing: set[str] = Field(default_factory=set)
# Allow FastAPI UploadFile instances
model_config = ConfigDict(arbitrary_types_allowed=True)
def _skip_token_threshold(extension: str) -> bool:
"""Return True if this file extension should bypass the token limit."""
return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
if max(width, height) <= cap:
return width, height
@@ -271,17 +264,7 @@ def categorize_uploaded_files(
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
exceeds_threshold = (
token_threshold is not None and token_count > token_threshold
)
if exceeds_threshold and _skip_token_threshold(extension):
# Exempt extensions (e.g. spreadsheets) are accepted
# but flagged to skip indexing — only metadata is
# injected into the LLM context.
results.acceptable.append(upload)
results.acceptable_file_to_token_count[filename] = token_count
results.skip_indexing.add(filename)
elif exceeds_threshold:
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,

View File

@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
return ChatFileType.IMAGE
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
return ChatFileType.TABULAR
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
return ChatFileType.CSV
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
return ChatFileType.DOC

View File

@@ -21,6 +21,7 @@ from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -37,7 +38,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -98,7 +98,7 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=not MULTI_TENANT,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(

View File

@@ -116,7 +116,7 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant EE deployments only.
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -1,4 +1,3 @@
import io
import json
from typing import Any
from typing import cast
@@ -10,7 +9,6 @@ from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_chat_file_by_id
@@ -171,13 +169,10 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
chat_file = self._load_file(file_id)
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
# DOC type in a loaded file means plaintext extraction failed and the
# content is the original binary (e.g. raw PDF/DOCX bytes).
if chat_file.file_type not in (
ChatFileType.PLAIN_TEXT,
ChatFileType.TABULAR,
):
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
raise ToolCallException(
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
llm_facing_message=(
@@ -186,19 +181,7 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
)
try:
if chat_file.file_type == ChatFileType.PLAIN_TEXT:
full_text = chat_file.content.decode("utf-8", errors="replace")
else:
full_text = (
extract_file_text(
file=io.BytesIO(chat_file.content),
file_name=chat_file.filename or "",
break_on_unprocessable=False,
)
or ""
)
except ToolCallException:
raise
full_text = chat_file.content.decode("utf-8", errors="replace")
except Exception:
raise ToolCallException(
message=f"Failed to decode file {file_id}",

View File

@@ -5,7 +5,6 @@ import asyncio
import json
import logging
import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
QUESTION_TIMEOUT_SECONDS = 300
QUESTION_RETRY_PAUSE_SECONDS = 30
MAX_QUESTION_ATTEMPTS = 3
@dataclass(frozen=True)
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
return f"{normalized}/api"
def load_completed_question_ids(output_file: Path) -> set[str]:
if not output_file.exists():
return set()
completed_ids: set[str] = set()
with output_file.open("r", encoding="utf-8") as file:
for line in file:
stripped = line.strip()
if not stripped:
continue
try:
record = json.loads(stripped)
except json.JSONDecodeError:
continue
question_id = record.get("question_id")
if isinstance(question_id, str) and question_id:
completed_ids.add(question_id)
return completed_ids
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
@@ -373,7 +348,6 @@ async def generate_answers(
api_base: str,
api_key: str,
parallelism: int,
skipped: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
@@ -408,178 +382,58 @@ async def generate_answers(
write_lock = asyncio.Lock()
completed = 0
successful = 0
stuck_count = 0
failed_questions: list[FailedQuestionRecord] = []
remaining_count = len(questions)
overall_total = remaining_count + skipped
question_durations: list[float] = []
run_start_time = time.monotonic()
def print_progress() -> None:
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
elapsed = time.monotonic() - run_start_time
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
done = skipped + completed
bar_width = 30
filled = (
int(bar_width * done / overall_total)
if overall_total
else bar_width
)
bar = "" * filled + "" * (bar_width - filled)
pct = (done / overall_total * 100) if overall_total else 100.0
parts = (
f"\r{bar} {pct:5.1f}% "
f"[{done}/{overall_total}] "
f"avg {avg_time:.1f}s/q "
f"elapsed {elapsed:.0f}s "
f"ETA {eta:.0f}s "
f"(ok:{successful} fail:{len(failed_questions)}"
)
if stuck_count:
parts += f" stuck:{stuck_count}"
if skipped:
parts += f" skip:{skipped}"
parts += ")"
sys.stderr.write(parts)
sys.stderr.flush()
print_progress()
total = len(questions)
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
nonlocal stuck_count
last_error: Exception | None = None
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
q_start = time.monotonic()
try:
async with semaphore:
result = await asyncio.wait_for(
submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
),
timeout=QUESTION_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
async with progress_lock:
stuck_count += 1
logger.warning(
"Question %s timed out after %ss (attempt %s/%s, "
"total stuck: %s) — retrying in %ss",
question_record.question_id,
QUESTION_TIMEOUT_SECONDS,
attempt,
MAX_QUESTION_ATTEMPTS,
stuck_count,
QUESTION_RETRY_PAUSE_SECONDS,
)
print_progress()
last_error = TimeoutError(
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
)
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
continue
except Exception as exc:
duration = time.monotonic() - q_start
async with progress_lock:
completed += 1
question_durations.append(duration)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
remaining_count,
)
print_progress()
return
duration = time.monotonic() - q_start
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
except Exception as exc:
async with progress_lock:
completed += 1
successful += 1
question_durations.append(duration)
print_progress()
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
return
# All attempts exhausted due to timeouts
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(last_error),
)
)
logger.error(
"Question %s failed after %s timeout attempts (%s/%s)",
question_record.question_id,
MAX_QUESTION_ATTEMPTS,
completed,
remaining_count,
)
print_progress()
successful += 1
logger.info("Processed %s/%s questions", completed, total)
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
# Final newline after progress bar
sys.stderr.write("\n")
sys.stderr.flush()
total_elapsed = time.monotonic() - run_start_time
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
resume_suffix = (
f"{skipped} previously completed, "
f"{skipped + successful}/{overall_total} overall"
if skipped
else ""
)
logger.info(
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
successful,
remaining_count,
total_elapsed,
avg_time,
stuck_suffix,
resume_suffix,
)
if failed_questions:
logger.warning(
"%s questions failed:",
"Completed with %s failed questions and %s successful questions.",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
@@ -599,30 +453,7 @@ def main() -> None:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
completed_ids = load_completed_question_ids(args.output_file)
logger.info(
"Found %s already-answered question IDs in %s",
len(completed_ids),
args.output_file,
)
total_before_filter = len(questions)
questions = [q for q in questions if q.question_id not in completed_ids]
skipped = total_before_filter - len(questions)
if skipped:
logger.info(
"Resuming: %s/%s already answered, %s remaining",
skipped,
total_before_filter,
len(questions),
)
else:
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
if not questions:
logger.info("All questions already answered. Nothing to do.")
return
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
@@ -632,7 +463,6 @@ def main() -> None:
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
skipped=skipped,
)
)

View File

@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(
file_descriptor: FileDescriptor = {
"id": user_file.file_id,
"type": ChatFileType.TABULAR,
"type": ChatFileType.CSV,
"name": "data.csv",
"user_file_id": str(user_file.id),
}

View File

@@ -1,9 +1,3 @@
import mimetypes
from typing import Any
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
@@ -85,90 +79,3 @@ def test_send_message_with_text_file_attachment(admin_user: DATestUser) -> None:
assert (
"third line" in response.full_message.lower()
), "Chat response should contain the contents of the file"
def _set_token_threshold(admin_user: DATestUser, threshold_k: int) -> None:
"""Set the file token count threshold via admin settings API."""
response = requests.put(
f"{API_SERVER_URL}/admin/settings",
json={"file_token_count_threshold_k": threshold_k},
headers=admin_user.headers,
)
response.raise_for_status()
def _upload_raw(
filename: str,
content: bytes,
user: DATestUser,
) -> dict[str, Any]:
"""Upload a file and return the full JSON response (user_files + rejected_files)."""
mime_type, _ = mimetypes.guess_type(filename)
headers = user.headers.copy()
headers.pop("Content-Type", None)
response = requests.post(
f"{API_SERVER_URL}/user/projects/file/upload",
files=[("files", (filename, content, mime_type or "application/octet-stream"))],
headers=headers,
)
response.raise_for_status()
return response.json()
def test_csv_over_token_threshold_uploaded_not_indexed(
admin_user: DATestUser,
) -> None:
"""CSV exceeding token threshold is uploaded (accepted) but skips indexing."""
_set_token_threshold(admin_user, threshold_k=1)
try:
# ~2000 tokens with default tokenizer, well over 1K threshold
content = ("x " * 100 + "\n") * 20
result = _upload_raw("large.csv", content.encode(), admin_user)
assert len(result["user_files"]) == 1, "CSV should be accepted"
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
assert (
result["user_files"][0]["status"] == "SKIPPED"
), "CSV over threshold should be SKIPPED (uploaded but not indexed)"
assert (
result["user_files"][0]["chunk_count"] is None
), "Skipped file should have no chunks"
finally:
_set_token_threshold(admin_user, threshold_k=200)
def test_csv_under_token_threshold_uploaded_and_indexed(
admin_user: DATestUser,
) -> None:
"""CSV under token threshold is uploaded and queued for indexing."""
_set_token_threshold(admin_user, threshold_k=200)
try:
content = "col1,col2\na,b\n"
result = _upload_raw("small.csv", content.encode(), admin_user)
assert len(result["user_files"]) == 1, "CSV should be accepted"
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
assert (
result["user_files"][0]["status"] == "PROCESSING"
), "CSV under threshold should be PROCESSING (queued for indexing)"
finally:
_set_token_threshold(admin_user, threshold_k=200)
def test_txt_over_token_threshold_rejected(
admin_user: DATestUser,
) -> None:
"""Non-exempt file exceeding token threshold is rejected entirely."""
_set_token_threshold(admin_user, threshold_k=1)
try:
# ~2000 tokens, well over 1K threshold. Unlike CSV, .txt is not
# exempt from the threshold so the file should be rejected.
content = ("x " * 100 + "\n") * 20
result = _upload_raw("big.txt", content.encode(), admin_user)
assert len(result["user_files"]) == 0, "File should not be accepted"
assert len(result["rejected_files"]) == 1, "File should be rejected"
assert "token limit" in result["rejected_files"][0]["reason"].lower()
finally:
_set_token_threshold(admin_user, threshold_k=200)

View File

@@ -300,66 +300,6 @@ class TestExtractContextFiles:
assert result.file_texts == []
assert result.total_token_count == 50
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_tool_metadata_file_id_matches_chat_history_file_id(
self, mock_load: MagicMock
) -> None:
"""The file_id in tool metadata (from extract_context_files) and the
file_id in chat history messages (from build_file_context) must
agree, otherwise the LLM sees different IDs for the same file across
turns.
In production, UserFile.id (UUID PK) differs from UserFile.file_id
(file-store path). Both pathways should produce the same file_id
(UserFile.id) for FileReaderTool."""
from onyx.chat.chat_utils import build_file_context
user_file_uuid = uuid4()
file_store_path = f"user_files/{user_file_uuid}/data.csv"
uf = UserFile(
id=user_file_uuid,
file_id=file_store_path,
name="data.csv",
token_count=100,
file_type="text/csv",
)
in_memory = InMemoryChatFile(
file_id=file_store_path,
content=b"col1,col2\na,b",
file_type=ChatFileType.TABULAR,
filename="data.csv",
)
mock_load.return_value = [in_memory]
# Pathway 1: extract_context_files (project/persona context)
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert len(result.file_metadata_for_tool) == 1
tool_metadata_file_id = result.file_metadata_for_tool[0].file_id
# Pathway 2: build_file_context (chat history path)
# In convert_chat_history, tool_file_id comes from
# file_descriptor["user_file_id"], which is str(UserFile.id)
ctx = build_file_context(
tool_file_id=str(user_file_uuid),
filename="data.csv",
file_type=ChatFileType.TABULAR,
)
chat_history_file_id = ctx.tool_metadata.file_id
# Both pathways must produce the same ID for the LLM
assert tool_metadata_file_id == chat_history_file_id, (
f"File ID mismatch: extract_context_files uses '{tool_metadata_file_id}' "
f"but build_file_context uses '{chat_history_file_id}'."
)
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
"""When vector DB is disabled, overflow produces FileToolMetadata."""
@@ -376,128 +316,6 @@ class TestExtractContextFiles:
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_metadata_only_files_not_counted_in_aggregate_tokens(
self, mock_load: MagicMock
) -> None:
"""Metadata-only files (TABULAR) should not count toward the token budget."""
text_file_id = str(uuid4())
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
# TABULAR file with large token count — should be excluded from aggregate
tabular_uf = _make_user_file(
token_count=50000, name="huge.xlsx", file_id=str(uuid4())
)
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
mock_load.return_value = [
_make_in_memory_file(file_id=text_file_id, content="text content"),
InMemoryChatFile(
file_id=str(tabular_uf.id),
content=b"binary xlsx",
file_type=ChatFileType.TABULAR,
filename="huge.xlsx",
),
]
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
# Text file fits (100 < 6000), so files should be loaded
assert result.file_texts == ["text content"]
# TABULAR file should appear as tool metadata, not in file_texts
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "huge.xlsx"
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_metadata_only_files_loaded_as_tool_metadata(
self, mock_load: MagicMock
) -> None:
"""When files fit, metadata-only files appear in file_metadata_for_tool."""
text_file_id = str(uuid4())
tabular_file_id = str(uuid4())
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
tabular_uf = _make_user_file(
token_count=500, name="data.csv", file_id=tabular_file_id
)
tabular_uf.file_type = "text/csv"
mock_load.return_value = [
_make_in_memory_file(file_id=text_file_id, content="hello"),
InMemoryChatFile(
file_id=tabular_file_id,
content=b"col1,col2\na,b",
file_type=ChatFileType.TABULAR,
filename="data.csv",
),
]
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.file_texts == ["hello"]
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "data.csv"
# TABULAR should not appear in file_metadata (that's for citation)
assert all(m.filename != "data.csv" for m in result.file_metadata)
def test_overflow_with_vector_db_preserves_metadata_only_tool_metadata(
self,
) -> None:
"""When text files overflow with vector DB enabled, metadata-only files
should still be exposed via file_metadata_for_tool since they aren't
in the vector DB and would otherwise be inaccessible."""
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
# Text files overflow → search filter enabled
assert result.use_as_search_filter is True
assert result.file_texts == []
# TABULAR file should still be in tool metadata
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "data.xlsx"
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
def test_overflow_no_vector_db_includes_all_files_in_tool_metadata(self) -> None:
"""When vector DB is disabled and files overflow, all files
(both text and metadata-only) appear in file_metadata_for_tool."""
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is False
assert len(result.file_metadata_for_tool) == 2
filenames = {m.filename for m in result.file_metadata_for_tool}
assert filenames == {"bigfile.txt", "data.xlsx"}
# ===========================================================================
# Search filter + search_usage determination

View File

@@ -644,92 +644,6 @@ class TestConstructMessageHistory:
assert "Project file 0 content" in project_message.message
assert "Project file 1 content" in project_message.message
def test_file_metadata_for_tool_produces_message(self) -> None:
"""When context_files has file_metadata_for_tool, a metadata listing
message should be injected into the history."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Analyze the spreadsheet", MessageType.USER, 5)
context_files = ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=False,
total_token_count=0,
file_metadata=[],
uncapped_token_count=0,
file_metadata_for_tool=[
FileToolMetadata(
file_id="xlsx-1",
filename="report.xlsx",
approx_char_count=100000,
),
],
)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=[user_msg],
reminder_message=None,
context_files=context_files,
available_tokens=1000,
token_counter=_simple_token_counter,
)
# Should have: system, tool_metadata_message, user
assert len(result) == 3
metadata_msg = result[1]
assert metadata_msg.message_type == MessageType.USER
assert "report.xlsx" in metadata_msg.message
assert "xlsx-1" in metadata_msg.message
def test_metadata_only_and_text_files_both_present(self) -> None:
"""When both text content and tool metadata are present, both messages
should appear in the history."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Summarize everything", MessageType.USER, 5)
context_files = ExtractedContextFiles(
file_texts=["Text file content here"],
image_files=[],
use_as_search_filter=False,
total_token_count=100,
file_metadata=[
ContextFileMetadata(
file_id="txt-1",
filename="notes.txt",
file_content="Text file content here",
),
],
uncapped_token_count=100,
file_metadata_for_tool=[
FileToolMetadata(
file_id="xlsx-1",
filename="data.xlsx",
approx_char_count=50000,
),
],
)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=[user_msg],
reminder_message=None,
context_files=context_files,
available_tokens=2000,
token_counter=_simple_token_counter,
)
# Should have: system, context_files_message, tool_metadata_message, user
assert len(result) == 4
# Context files message (text content)
assert "documents" in result[1].message
assert "Text file content here" in result[1].message
# Tool metadata message
assert "data.xlsx" in result[2].message
assert result[3] == user_msg
def _simple_token_counter(text: str) -> int:
"""Approximate token counter for tests (~4 chars per token)."""

View File

@@ -139,7 +139,7 @@ def test_csv_file_type() -> None:
result = _extract_referenced_file_descriptors([tool_call], message)
assert len(result) == 1
assert result[0]["type"] == ChatFileType.TABULAR
assert result[0]["type"] == ChatFileType.CSV
def test_unknown_extension_defaults_to_plain_text() -> None:

View File

@@ -1,23 +1,15 @@
"""Tests for Canvas connector — client, credentials, conversion."""
"""Tests for Canvas connector — client (PR1)."""
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.canvas.connector import CanvasConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -26,77 +18,6 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
FAKE_TOKEN = "fake-canvas-token"
def _mock_course(
course_id: int = 1,
name: str = "Intro to CS",
course_code: str = "CS101",
) -> dict[str, Any]:
return {
"id": course_id,
"name": name,
"course_code": course_code,
"created_at": "2025-01-01T00:00:00Z",
"workflow_state": "available",
}
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
"""Build a connector with mocked credential validation."""
with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=base_url)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
return connector
def _mock_page(
page_id: int = 10,
title: str = "Syllabus",
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"page_id": page_id,
"url": "syllabus",
"title": title,
"body": "<p>Welcome to the course</p>",
"created_at": "2025-01-15T00:00:00Z",
"updated_at": updated_at,
}
def _mock_assignment(
assignment_id: int = 20,
name: str = "Homework 1",
course_id: int = 1,
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": assignment_id,
"name": name,
"description": "<p>Solve these problems</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
"course_id": course_id,
"created_at": "2025-01-20T00:00:00Z",
"updated_at": updated_at,
"due_at": "2025-02-01T23:59:00Z",
}
def _mock_announcement(
announcement_id: int = 30,
title: str = "Class Cancelled",
course_id: int = 1,
posted_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"id": announcement_id,
"title": title,
"message": "<p>No class today</p>",
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
"posted_at": posted_at,
}
def _mock_response(
status_code: int = 200,
json_data: Any = None,
@@ -404,57 +325,6 @@ class TestGet:
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.paginate tests
# ---------------------------------------------------------------------------
class TestPaginate:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[{"id": 1}, {"id": 2}]
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 1
assert pages[0] == [{"id": 1}, {"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_two_pages(self, mock_requests: MagicMock) -> None:
next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
page2 = _mock_response(json_data=[{"id": 2}])
mock_requests.get.side_effect = [page1, page2]
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert len(pages) == 2
assert pages[0] == [{"id": 1}]
assert pages[1] == [{"id": 2}]
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
pages = list(client.paginate("courses"))
assert pages == []
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
# ---------------------------------------------------------------------------
@@ -509,368 +379,3 @@ class TestParseNextLink:
with pytest.raises(OnyxError, match="must use https"):
self.client._parse_next_link(header)
# ---------------------------------------------------------------------------
# CanvasConnector — credential loading
# ---------------------------------------------------------------------------
class TestLoadCredentials:
def _assert_load_credentials_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert load_credentials raises expected_error for a given status."""
mock_requests.get.return_value = _mock_response(status_code, {})
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(expected_error):
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
assert result is None
assert connector._canvas_client is not None
def test_canvas_client_raises_without_credentials(self) -> None:
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(ConnectorMissingCredentialError):
_ = connector.canvas_client
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_insufficient_permissions(
self, mock_requests: MagicMock
) -> None:
self._assert_load_credentials_raises(
403, InsufficientPermissionsError, mock_requests
)
# ---------------------------------------------------------------------------
# CanvasConnector — URL normalization
# ---------------------------------------------------------------------------
class TestConnectorUrlNormalization:
def test_strips_api_v1_suffix(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_strips_trailing_slash(self) -> None:
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
def test_no_change_for_clean_url(self) -> None:
connector = _build_connector(base_url=FAKE_BASE_URL)
result = connector.canvas_base_url
expected = FAKE_BASE_URL
assert result == expected
# ---------------------------------------------------------------------------
# CanvasConnector — document conversion
# ---------------------------------------------------------------------------
class TestDocumentConversion:
def setup_method(self) -> None:
self.connector = _build_connector()
def test_convert_page_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=10,
url="syllabus",
title="Syllabus",
body="<p>Welcome</p>",
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
expected_id = "canvas-page-1-10"
expected_metadata = {"course_id": "1", "type": "page"}
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Syllabus"
assert doc.metadata == expected_metadata
assert doc.sections[0].link is not None
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
assert doc.doc_updated_at == expected_updated_at
def test_convert_page_without_body(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=11,
url="empty-page",
title="Empty Page",
body=None,
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
section_text = doc.sections[0].text
assert section_text is not None
assert "Empty Page" in section_text
assert "<p>" not in section_text
def test_convert_assignment_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=20,
name="Homework 1",
description="<p>Solve these</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at="2025-02-01T23:59:00Z",
)
doc = self.connector._convert_assignment_to_document(assignment)
expected_id = "canvas-assignment-1-20"
expected_due_text = "Due: February 01, 2025 23:59 UTC"
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Homework 1"
assert doc.sections[0].text is not None
assert expected_due_text in doc.sections[0].text
def test_convert_assignment_without_description(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=21,
name="Quiz 1",
description=None,
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at=None,
)
doc = self.connector._convert_assignment_to_document(assignment)
section_text = doc.sections[0].text
assert section_text is not None
assert "Quiz 1" in section_text
assert "Due:" not in section_text
def test_convert_announcement_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=30,
title="Class Cancelled",
message="<p>No class today</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
posted_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
expected_id = "canvas-announcement-1-30"
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Class Cancelled"
assert doc.doc_updated_at == expected_updated_at
def test_convert_announcement_without_posted_at(self) -> None:
from onyx.connectors.canvas.connector import CanvasAnnouncement
announcement = CanvasAnnouncement(
id=31,
title="TBD Announcement",
message=None,
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
posted_at=None,
course_id=1,
)
doc = self.connector._convert_announcement_to_document(announcement)
assert doc.doc_updated_at is None
# ---------------------------------------------------------------------------
# CanvasConnector — validate_connector_settings
# ---------------------------------------------------------------------------
class TestValidateConnectorSettings:
def _assert_validate_raises(
self,
status_code: int,
expected_error: type[Exception],
mock_requests: MagicMock,
) -> None:
"""Helper: assert validate_connector_settings raises expected_error."""
success_resp = _mock_response(json_data=[_mock_course()])
fail_resp = _mock_response(status_code, {})
mock_requests.get.side_effect = [success_resp, fail_resp]
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
with pytest.raises(expected_error):
connector.validate_connector_settings()
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_success(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
connector = _build_connector()
connector.validate_connector_settings() # should not raise
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
# ---------------------------------------------------------------------------
# _list_* pagination tests
# ---------------------------------------------------------------------------
class TestListCourses:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
)
connector = _build_connector()
result = connector._list_courses()
assert len(result) == 2
assert result[0].id == 1
assert result[1].id == 2
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_courses()
assert result == []
class TestListPages:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_page(10), _mock_page(11, "Notes")]
)
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert len(result) == 2
assert result[0].page_id == 10
assert result[1].page_id == 11
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_pages(course_id=1)
assert result == []
class TestListAssignments:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
)
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert len(result) == 2
assert result[0].id == 20
assert result[1].id == 21
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_assignments(course_id=1)
assert result == []
class TestListAnnouncements:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_single_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(
json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
)
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert len(result) == 2
assert result[0].id == 30
assert result[1].id == 31
@patch("onyx.connectors.canvas.client.rl_requests")
def test_empty_response(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
connector = _build_connector()
result = connector._list_announcements(course_id=1)
assert result == []

View File

@@ -1,45 +0,0 @@
from unittest.mock import AsyncMock
from unittest.mock import patch
import pytest
from discord.errors import LoginFailure
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.exceptions import CredentialInvalidError
def _build_connector(token: str = "fake-bot-token") -> DiscordConnector:
connector = DiscordConnector()
connector.load_credentials({"discord_bot_token": token})
return connector
@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
@patch("onyx.connectors.discord.connector.Client.login", new_callable=AsyncMock)
def test_validate_success(
mock_login: AsyncMock,
mock_close: AsyncMock,
) -> None:
connector = _build_connector()
connector.validate_connector_settings()
mock_login.assert_awaited_once_with("fake-bot-token")
mock_close.assert_awaited_once()
@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
@patch(
"onyx.connectors.discord.connector.Client.login",
new_callable=AsyncMock,
side_effect=LoginFailure("Improper token has been passed."),
)
def test_validate_invalid_token(
mock_login: AsyncMock, # noqa: ARG001
mock_close: AsyncMock,
) -> None:
connector = _build_connector(token="bad-token")
with pytest.raises(CredentialInvalidError, match="Invalid Discord bot token"):
connector.validate_connector_settings()
mock_close.assert_awaited_once()

View File

@@ -1,225 +0,0 @@
"""Tests for get_chat_sessions_by_user filtering behavior.
Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession
def _make_session(
user_id: UUID,
time_created: datetime | None = None,
time_updated: datetime | None = None,
description: str = "",
) -> MagicMock:
"""Create a mock ChatSession with the given attributes."""
session = MagicMock(spec=ChatSession)
session.id = uuid4()
session.user_id = user_id
session.time_created = time_created or datetime.now(timezone.utc)
session.time_updated = time_updated or session.time_created
session.description = description
session.deleted = False
session.onyxbot_flow = False
session.project_id = None
return session
@pytest.fixture
def user_id() -> UUID:
return uuid4()
@pytest.fixture
def old_time() -> datetime:
"""A timestamp well outside the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(hours=1)
@pytest.fixture
def recent_time() -> datetime:
"""A timestamp within the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(minutes=2)
class TestGetChatSessionsByUser:
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
def test_filters_out_failed_sessions(
self, user_id: UUID, old_time: datetime
) -> None:
"""Sessions with only SYSTEM messages should be excluded."""
valid_session = _make_session(user_id, time_created=old_time)
failed_session = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
# First execute: returns all sessions
# Second execute: returns only the valid session's ID (has non-system msgs)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
valid_session,
failed_session,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == valid_session.id
def test_keeps_recent_sessions_without_messages(
self, user_id: UUID, recent_time: datetime
) -> None:
"""Recently created sessions should be kept even without messages."""
recent_session = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [recent_session]
db_session.execute.side_effect = [mock_result_1]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == recent_session.id
# Should only have been called once — no second query needed
# because the recent session is within the leeway window
assert db_session.execute.call_count == 1
def test_include_failed_chats_skips_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""When include_failed_chats=True, no filtering should occur."""
session_a = _make_session(user_id, time_created=old_time)
session_b = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=True,
)
assert len(result) == 2
# Only one DB call — no second query for message validation
assert db_session.execute.call_count == 1
def test_limit_applied_after_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""Limit should be applied after filtering, not before."""
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
valid_ids = [s.id for s in sessions[:3]]
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = sessions
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = valid_ids
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
limit=2,
)
assert len(result) == 2
# Should be the first 2 valid sessions (order preserved)
assert result[0].id == sessions[0].id
assert result[1].id == sessions[1].id
def test_mixed_recent_and_old_sessions(
self, user_id: UUID, old_time: datetime, recent_time: datetime
) -> None:
"""Mix of recent and old sessions should filter correctly."""
old_valid = _make_session(user_id, time_created=old_time)
old_failed = _make_session(user_id, time_created=old_time)
recent_no_msgs = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
old_valid,
old_failed,
recent_no_msgs,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
result_ids = {cs.id for cs in result}
assert old_valid.id in result_ids
assert recent_no_msgs.id in result_ids
assert old_failed.id not in result_ids
def test_empty_result(self, user_id: UUID) -> None:
"""No sessions should return empty list without errors."""
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert result == []
assert db_session.execute.call_count == 1

View File

@@ -40,8 +40,6 @@ def test_send_task_includes_expires(
user_files=user_files,
rejected_files=[],
id_to_temp_id={},
skip_indexing_filenames=set(),
indexable_files=user_files,
)
mock_user = MagicMock()

View File

@@ -11,13 +11,30 @@ from onyx.hooks.api_dependencies import require_hook_enabled
class TestRequireHookEnabled:
def test_raises_when_multi_tenant(self) -> None:
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", True):
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", True),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.SINGLE_TENANT_ONLY
assert exc_info.value.status_code == 403
assert "multi-tenant" in exc_info.value.detail
def test_passes_when_single_tenant(self) -> None:
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", False):
def test_raises_when_flag_disabled(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", False),
):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.ENV_VAR_GATED
assert exc_info.value.status_code == 403
assert "HOOK_ENABLED" in exc_info.value.detail
def test_passes_when_enabled_single_tenant(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
require_hook_enabled() # must not raise

View File

@@ -9,11 +9,11 @@ import httpx
import pytest
from pydantic import BaseModel
from ee.onyx.hooks.executor import _execute_hook_impl as execute_hook
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
@@ -118,30 +118,28 @@ def db_session() -> MagicMock:
@pytest.mark.parametrize(
"multi_tenant,hook",
"hooks_available,hook",
[
# MULTI_TENANT=True exits before the DB lookup — hook is irrelevant.
pytest.param(True, None, id="multi_tenant"),
pytest.param(False, None, id="hook_not_found"),
pytest.param(False, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(False, _make_hook(endpoint_url=None), id="no_endpoint_url"),
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
multi_tenant: bool,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", multi_tenant),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
@@ -166,16 +164,14 @@ def test_success_returns_validated_model_and_sets_reachable(
hook = _make_hook()
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
@@ -199,14 +195,14 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
hook = _make_hook(is_reachable=True)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
@@ -228,16 +224,14 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
@@ -264,16 +258,14 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
@@ -392,14 +384,14 @@ def test_http_failure_paths(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
@@ -451,14 +443,14 @@ def test_authorization_header(
hook = _make_hook(api_key=api_key)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit"),
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
@@ -497,13 +489,13 @@ def test_persist_session_failure_is_swallowed(
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"ee.onyx.hooks.executor.get_session_with_current_tenant",
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
@@ -564,16 +556,14 @@ def test_response_validation_failure_respects_fail_strategy(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
@@ -629,13 +619,13 @@ def test_unexpected_exception_in_inner_respects_fail_strategy(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"ee.onyx.hooks.executor._execute_hook_inner",
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
@@ -668,19 +658,17 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"ee.onyx.hooks.executor.update_hook__no_commit",
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))

View File

@@ -1,4 +1,4 @@
"""Unit tests for ee.onyx.server.features.hooks.api helpers.
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
@@ -16,13 +16,13 @@ from unittest.mock import patch
import httpx
import pytest
from ee.onyx.server.features.hooks.api import _check_ssrf_safety
from ee.onyx.server.features.hooks.api import _raise_for_validation_failure
from ee.onyx.server.features.hooks.api import _validate_endpoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
@@ -117,28 +117,28 @@ class TestCheckSsrfSafety:
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("ee.onyx.server.features.hooks.api._check_ssrf_safety"):
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
@@ -150,21 +150,21 @@ class TestValidateEndpoint:
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.timeout
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
@@ -179,7 +179,7 @@ class TestValidateEndpoint:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
@@ -189,7 +189,7 @@ class TestValidateEndpoint:
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
@@ -198,7 +198,7 @@ class TestValidateEndpoint:
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
@@ -206,7 +206,7 @@ class TestValidateEndpoint:
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)

View File

@@ -417,57 +417,3 @@ def test_categorize_text_under_token_limit_accepted(
assert len(result.acceptable) == 1
assert result.acceptable_file_to_token_count["ok.txt"] == 500
# --- skip-indexing vs rejection by file type ---
def test_csv_over_token_threshold_accepted_skip_indexing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""CSV exceeding token threshold is uploaded but flagged to skip indexing."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
text = "x" * 2000 # 2000 tokens > 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)
upload = _make_upload("large.csv", size=2000, content=text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "large.csv"
assert "large.csv" in result.skip_indexing
assert len(result.rejected) == 0
def test_csv_under_token_threshold_accepted_and_indexed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""CSV under token threshold is uploaded and indexed normally."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
text = "x" * 500 # 500 tokens < 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)
upload = _make_upload("small.csv", size=500, content=text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "small.csv"
assert "small.csv" not in result.skip_indexing
assert len(result.rejected) == 0
def test_pdf_over_token_threshold_rejected(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""PDF exceeding token threshold is rejected entirely (not uploaded)."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
text = "x" * 2000 # 2000 tokens > 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: text)
upload = _make_upload("big.pdf", size=2000, content=text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 1
assert result.rejected[0].filename == "big.pdf"
assert "1K token limit" in result.rejected[0].reason
assert len(result.acceptable) == 0

View File

@@ -82,7 +82,7 @@ class TestChatFileConversion:
ChatLoadedFile(
file_id="file-2",
content=b"csv,data\n1,2",
file_type=ChatFileType.TABULAR,
file_type=ChatFileType.CSV,
filename="data.csv",
content_text="csv,data\n1,2",
token_count=5,

View File

@@ -203,7 +203,6 @@ prompt_or_default() {
local default_value="$2"
read_prompt_line "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
return 0
}
prompt_yn_or_default() {
@@ -211,7 +210,6 @@ prompt_yn_or_default() {
local default_value="$2"
read_prompt_char "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
return 0
}
confirm_action() {

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.39
version: 0.4.38
appVersion: latest
annotations:
category: Productivity

File diff suppressed because it is too large Load Diff

View File

@@ -1,77 +0,0 @@
{{- if and .Values.monitoring.serviceMonitors.enabled .Values.vectorDB.enabled }}
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_monitoring.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docfetching.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docfetching.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docprocessing.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docprocessing.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- end }}

View File

@@ -1,15 +0,0 @@
{{- if .Values.monitoring.grafana.dashboards.enabled }}
---
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ include "onyx.fullname" . }}-indexing-pipeline-dashboard
labels:
{{- include "onyx.labels" . | nindent 4 }}
grafana_dashboard: "1"
annotations:
grafana_folder: "Onyx"
data:
onyx-indexing-pipeline.json: |
{{- .Files.Get "dashboards/indexing-pipeline.json" | nindent 4 }}
{{- end }}

View File

@@ -256,20 +256,6 @@ tooling:
# -- Which client binary to call; change if your image uses a non-default path.
psqlBinary: psql
monitoring:
grafana:
dashboards:
# -- Set to true to deploy Grafana dashboard ConfigMaps for the Onyx indexing pipeline.
# Requires kube-prometheus-stack (or equivalent) with the Grafana sidecar enabled and watching this namespace.
# The sidecar must be configured with label selector: grafana_dashboard=1
enabled: false
serviceMonitors:
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
enabled: false
labels: {}
serviceAccount:
# Specifies whether a service account should be created
create: false

View File

@@ -19,10 +19,6 @@ module "eks" {
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
enable_cluster_creator_admin_permissions = true
# Control plane logging
cluster_enabled_log_types = var.cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.cloudwatch_log_group_retention_in_days
eks_managed_node_group_defaults = {
ami_type = "AL2023_x86_64_STANDARD"
}

View File

@@ -161,25 +161,3 @@ variable "rds_db_connect_arn" {
description = "Full rds-db:connect ARN to allow (required when enable_rds_iam_for_service_account is true)"
default = null
}
variable "cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
validation {
condition = alltrue([for t in var.cluster_enabled_log_types : contains(["api", "audit", "authenticator", "controllerManager", "scheduler"], t)])
error_message = "Each entry must be one of: api, audit, authenticator, controllerManager, scheduler."
}
}
variable "cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -54,9 +54,6 @@ module "postgres" {
password = var.postgres_password
tags = local.merged_tags
enable_rds_iam_auth = var.enable_iam_auth
backup_retention_period = var.postgres_backup_retention_period
backup_window = var.postgres_backup_window
}
module "s3" {
@@ -83,10 +80,6 @@ module "eks" {
public_cluster_enabled = var.public_cluster_enabled
private_cluster_enabled = var.private_cluster_enabled
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
# Control plane logging
cluster_enabled_log_types = var.eks_cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.eks_cloudwatch_log_group_retention_in_days
}
module "waf" {

View File

@@ -250,34 +250,3 @@ variable "opensearch_subnet_ids" {
description = "Subnet IDs for OpenSearch. If empty, uses first 3 private subnets."
default = []
}
# RDS Backup Configuration
variable "postgres_backup_retention_period" {
type = number
description = "Number of days to retain automated RDS backups (0 to disable)"
default = 7
}
variable "postgres_backup_window" {
type = string
description = "Preferred UTC time window for automated RDS backups (hh24:mi-hh24:mi)"
default = "03:00-04:00"
}
# EKS Control Plane Logging
variable "eks_cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
}
variable "eks_cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.eks_cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -44,79 +44,5 @@ resource "aws_db_instance" "this" {
publicly_accessible = false
deletion_protection = true
storage_encrypted = true
# Automated backups
backup_retention_period = var.backup_retention_period
backup_window = var.backup_window
tags = var.tags
}
# CloudWatch alarm for CPU utilization monitoring
resource "aws_cloudwatch_metric_alarm" "cpu_utilization" {
alarm_name = "${var.identifier}-cpu-utilization"
alarm_description = "RDS CPU utilization for ${var.identifier}"
comparison_operator = "GreaterThanThreshold"
evaluation_periods = var.cpu_alarm_evaluation_periods
metric_name = "CPUUtilization"
namespace = "AWS/RDS"
period = var.cpu_alarm_period
statistic = "Average"
threshold = var.cpu_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
}
# CloudWatch alarm for disk IO monitoring
resource "aws_cloudwatch_metric_alarm" "read_iops" {
alarm_name = "${var.identifier}-read-iops"
alarm_description = "RDS ReadIOPS for ${var.identifier}"
comparison_operator = "GreaterThanThreshold"
evaluation_periods = var.iops_alarm_evaluation_periods
metric_name = "ReadIOPS"
namespace = "AWS/RDS"
period = var.iops_alarm_period
statistic = "Average"
threshold = var.read_iops_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
}
# CloudWatch alarm for freeable memory monitoring
resource "aws_cloudwatch_metric_alarm" "freeable_memory" {
alarm_name = "${var.identifier}-freeable-memory"
alarm_description = "RDS freeable memory for ${var.identifier}"
comparison_operator = "LessThanThreshold"
evaluation_periods = var.memory_alarm_evaluation_periods
metric_name = "FreeableMemory"
namespace = "AWS/RDS"
period = var.memory_alarm_period
statistic = "Average"
threshold = var.memory_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
tags = var.tags
}

View File

@@ -67,131 +67,3 @@ variable "enable_rds_iam_auth" {
description = "Enable AWS IAM database authentication for this RDS instance"
default = false
}
variable "backup_retention_period" {
type = number
description = "Number of days to retain automated backups (0 to disable)"
default = 7
validation {
condition = var.backup_retention_period >= 0 && var.backup_retention_period <= 35
error_message = "backup_retention_period must be between 0 and 35 (AWS RDS limit)."
}
}
variable "backup_window" {
type = string
description = "Preferred UTC time window for automated backups (hh24:mi-hh24:mi)"
default = "03:00-04:00"
validation {
condition = can(regex("^([01]\\d|2[0-3]):[0-5]\\d-([01]\\d|2[0-3]):[0-5]\\d$", var.backup_window))
error_message = "backup_window must be in hh24:mi-hh24:mi format (e.g. \"03:00-04:00\")."
}
}
# CloudWatch CPU alarm configuration
variable "cpu_alarm_threshold" {
type = number
description = "CPU utilization percentage threshold for the CloudWatch alarm"
default = 80
validation {
condition = var.cpu_alarm_threshold >= 0 && var.cpu_alarm_threshold <= 100
error_message = "cpu_alarm_threshold must be between 0 and 100 (percentage)."
}
}
variable "cpu_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the threshold must be breached before alarming"
default = 3
validation {
condition = var.cpu_alarm_evaluation_periods >= 1
error_message = "cpu_alarm_evaluation_periods must be at least 1."
}
}
variable "cpu_alarm_period" {
type = number
description = "Period in seconds over which the CPU metric is evaluated"
default = 300
validation {
condition = var.cpu_alarm_period >= 60 && var.cpu_alarm_period % 60 == 0
error_message = "cpu_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "memory_alarm_threshold" {
type = number
description = "Freeable memory threshold in bytes. Alarm fires when memory drops below this value."
default = 256000000 # 256 MB
validation {
condition = var.memory_alarm_threshold > 0
error_message = "memory_alarm_threshold must be greater than 0."
}
}
variable "memory_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the threshold must be breached before alarming"
default = 3
validation {
condition = var.memory_alarm_evaluation_periods >= 1
error_message = "memory_alarm_evaluation_periods must be at least 1."
}
}
variable "memory_alarm_period" {
type = number
description = "Period in seconds over which the freeable memory metric is evaluated"
default = 300
validation {
condition = var.memory_alarm_period >= 60 && var.memory_alarm_period % 60 == 0
error_message = "memory_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "read_iops_alarm_threshold" {
type = number
description = "ReadIOPS threshold. Alarm fires when IOPS exceeds this value."
default = 3000
validation {
condition = var.read_iops_alarm_threshold > 0
error_message = "read_iops_alarm_threshold must be greater than 0."
}
}
variable "iops_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the IOPS threshold must be breached before alarming"
default = 3
validation {
condition = var.iops_alarm_evaluation_periods >= 1
error_message = "iops_alarm_evaluation_periods must be at least 1."
}
}
variable "iops_alarm_period" {
type = number
description = "Period in seconds over which the IOPS metric is evaluated"
default = 300
validation {
condition = var.iops_alarm_period >= 60 && var.iops_alarm_period % 60 == 0
error_message = "iops_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "alarm_actions" {
type = list(string)
description = "List of ARNs to notify when the alarm transitions state (e.g. SNS topic ARNs)"
default = []
}

View File

@@ -1,349 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": { "type": "grafana", "uid": "-- Grafana --" },
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"id": null,
"links": [],
"liveNow": true,
"panels": [
{
"title": "Client-Side Search Latency (P50 / P95 / P99)",
"description": "End-to-end latency as measured by the Python client, including network round-trip and serialization overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 0 },
"id": 1,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "dashed" }
},
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "green", "value": null },
{ "color": "yellow", "value": 0.5 },
{ "color": "red", "value": 2.0 }
]
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P50",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P95",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P99",
"refId": "C"
}
]
},
{
"title": "Server-Side Search Latency (P50 / P95 / P99)",
"description": "OpenSearch server-side execution time from the 'took' field in the response. Does not include network or client-side overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 },
"id": 2,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "dashed" }
},
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "green", "value": null },
{ "color": "yellow", "value": 0.5 },
{ "color": "red", "value": 2.0 }
]
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P50",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P95",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P99",
"refId": "C"
}
]
},
{
"title": "Client-Side Latency by Search Type (P95)",
"description": "P95 client-side latency broken down by search type (hybrid, keyword, semantic, random, doc_id_retrieval).",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 10 },
"id": 3,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Search Throughput by Type",
"description": "Searches per second broken down by search type.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 10 },
"id": 4,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "searches/s",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "normal" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "ops",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (search_type) (rate(onyx_opensearch_search_total[5m]))",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Concurrent Searches In Progress",
"description": "Number of OpenSearch searches currently in flight, broken down by search type. Summed across all instances.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 20 },
"id": 5,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "searches",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "normal" },
"thresholdsStyle": { "mode": "off" }
},
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (search_type) (onyx_opensearch_searches_in_progress)",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Client vs Server Latency Overhead (P50)",
"description": "Difference between client-side and server-side P50 latency. Reveals network, serialization, and untracked OpenSearch overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 20 },
"id": 6,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "Client - Server overhead (P50)",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "Client P50",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "Server P50",
"refId": "C"
}
]
}
],
"refresh": "5s",
"schemaVersion": 37,
"style": "dark",
"tags": ["onyx", "opensearch", "search", "latency"],
"templating": {
"list": [
{
"current": {
"text": "Prometheus",
"value": "prometheus"
},
"includeAll": false,
"name": "DS_PROMETHEUS",
"options": [],
"query": "prometheus",
"refresh": 1,
"type": "datasource"
}
]
},
"time": { "from": "now-60m", "to": "now" },
"timepicker": {
"refresh_intervals": ["5s", "10s", "30s", "1m"]
},
"timezone": "",
"title": "Onyx OpenSearch Search Latency",
"uid": "onyx-opensearch-search-latency",
"version": 0,
"weekStart": ""
}

View File

@@ -73,17 +73,11 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Add NODE_OPTIONS argument
ARG NODE_OPTIONS
# SENTRY_AUTH_TOKEN is injected via BuildKit secret mount so it is never written
# to any image layer, build cache, or registry manifest.
# Use NODE_OPTIONS in the build command
RUN --mount=type=secret,id=sentry_auth_token,env=SENTRY_AUTH_TOKEN \
NODE_OPTIONS="${NODE_OPTIONS}" npx next build
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
# Step 2. Production image, copy all the files and run next
FROM base AS runner
@@ -156,9 +150,6 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

View File

@@ -100,7 +100,9 @@ function Button({
border={interactiveProps.prominence === "secondary"}
heightVariant={size}
widthVariant={width}
roundingVariant={isLarge ? "md" : size === "2xs" ? "xs" : "sm"}
roundingVariant={
isLarge ? "default" : size === "2xs" ? "mini" : "compact"
}
>
<div className="flex flex-row items-center gap-1">
{iconWrapper(Icon, size, !!children)}

View File

@@ -35,7 +35,7 @@ Interactive.Stateful <- selectVariant, state, interaction, onClick, href
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `roundingVariant` | `InteractiveContainerRoundingVariant` | `"md"` | Corner rounding preset (height is content-driven) |
| `roundingVariant` | `InteractiveContainerRoundingVariant` | `"default"` | Corner rounding preset (height is content-driven) |
| `width` | `WidthVariant` | `"full"` | Container width |
| `type` | `"submit" \| "button" \| "reset"` | `"button"` | HTML button type |
| `tooltip` | `string` | — | Tooltip text shown on hover |
@@ -63,7 +63,7 @@ import { LineItemButton } from "@opal/components";
<LineItemButton
selectVariant="select-heavy"
state={isSelected ? "selected" : "empty"}
roundingVariant="sm"
roundingVariant="compact"
onClick={handleClick}
title="gpt-4o"
sizePreset="main-ui"

View File

@@ -33,7 +33,7 @@ type LineItemButtonOwnProps = Pick<
/** Interactive select variant. @default "select-light" */
selectVariant?: "select-light" | "select-heavy";
/** Corner rounding preset (height is always content-driven). @default "md" */
/** Corner rounding preset (height is always content-driven). @default "default" */
roundingVariant?: InteractiveContainerRoundingVariant;
/** Container width. @default "full" */
@@ -65,7 +65,7 @@ function LineItemButton({
type = "button",
// Sizing
roundingVariant = "md",
roundingVariant = "default",
width = "full",
tooltip,
tooltipSide = "top",

View File

@@ -127,7 +127,7 @@ function OpenButton({
widthVariant={width}
roundingVariant={
roundingVariantOverride ??
(isLarge ? "md" : size === "2xs" ? "xs" : "sm")
(isLarge ? "default" : size === "2xs" ? "mini" : "compact")
}
>
<div

View File

@@ -101,7 +101,9 @@ function SelectButton({
type={type}
heightVariant={size}
widthVariant={width}
roundingVariant={isLarge ? "md" : size === "2xs" ? "xs" : "sm"}
roundingVariant={
isLarge ? "default" : size === "2xs" ? "mini" : "compact"
}
>
<div
className={cn(

View File

@@ -3,8 +3,7 @@ import { Card } from "@opal/components";
const BACKGROUND_VARIANTS = ["none", "light", "heavy"] as const;
const BORDER_VARIANTS = ["none", "dashed", "solid"] as const;
const PADDING_VARIANTS = ["fit", "2xs", "xs", "sm", "md", "lg"] as const;
const ROUNDING_VARIANTS = ["xs", "sm", "md", "lg"] as const;
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
const meta: Meta<typeof Card> = {
title: "opal/components/Card",
@@ -18,9 +17,7 @@ type Story = StoryObj<typeof Card>;
export const Default: Story = {
render: () => (
<Card>
<p>
Default card with light background, no border, sm padding, md rounding.
</p>
<p>Default card with light background, no border, lg size.</p>
</Card>
),
};
@@ -49,24 +46,12 @@ export const BorderVariants: Story = {
),
};
export const PaddingVariants: Story = {
export const SizeVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{PADDING_VARIANTS.map((padding) => (
<Card key={padding} paddingVariant={padding} borderVariant="solid">
<p>paddingVariant: {padding}</p>
</Card>
))}
</div>
),
};
export const RoundingVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{ROUNDING_VARIANTS.map((rounding) => (
<Card key={rounding} roundingVariant={rounding} borderVariant="solid">
<p>roundingVariant: {rounding}</p>
{SIZE_VARIANTS.map((size) => (
<Card key={size} sizeVariant={size} borderVariant="solid">
<p>sizeVariant: {size}</p>
</Card>
))}
</div>
@@ -76,15 +61,15 @@ export const RoundingVariants: Story = {
export const AllCombinations: Story = {
render: () => (
<div className="flex flex-col gap-8">
{PADDING_VARIANTS.map((padding) => (
<div key={padding}>
<p className="font-bold pb-2">paddingVariant: {padding}</p>
{SIZE_VARIANTS.map((size) => (
<div key={size}>
<p className="font-bold pb-2">sizeVariant: {size}</p>
<div className="grid grid-cols-3 gap-4">
{BACKGROUND_VARIANTS.map((bg) =>
BORDER_VARIANTS.map((border) => (
<Card
key={`${padding}-${bg}-${border}`}
paddingVariant={padding}
key={`${size}-${bg}-${border}`}
sizeVariant={size}
backgroundVariant={bg}
borderVariant={border}
>

View File

@@ -6,53 +6,52 @@ A plain container component with configurable background, border, padding, and r
## Architecture
Padding and rounding are controlled independently:
The `sizeVariant` controls both padding and border-radius, mirroring the same mapping used by `Button` and `Interactive.Container`:
| `paddingVariant` | Class |
|------------------|---------|
| `"lg"` | `p-6` |
| `"md"` | `p-4` |
| `"sm"` | `p-2` |
| `"xs"` | `p-1` |
| `"2xs"` | `p-0.5` |
| `"fit"` | `p-0` |
| `roundingVariant` | Class |
|-------------------|--------------|
| `"xs"` | `rounded-04` |
| `"sm"` | `rounded-08` |
| `"md"` | `rounded-12` |
| `"lg"` | `rounded-16` |
| Size | Padding | Rounding |
|-----------|---------|----------------|
| `lg` | `p-2` | `rounded-12` |
| `md` | `p-1` | `rounded-08` |
| `sm` | `p-1` | `rounded-08` |
| `xs` | `p-0.5` | `rounded-04` |
| `2xs` | `p-0.5` | `rounded-04` |
| `fit` | `p-0` | `rounded-12` |
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset |
| `roundingVariant` | `RoundingVariants` | `"md"` | Border-radius preset |
| `sizeVariant` | `SizeVariant` | `"lg"` | Controls padding and border-radius |
| `backgroundVariant` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
| `borderVariant` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| `children` | `React.ReactNode` | — | Card content |
## Background Variants
- **`none`** — Transparent background. Use for seamless inline content.
- **`light`** — Subtle tinted background (`bg-background-tint-00`). The default, suitable for most cards.
- **`heavy`** — Stronger tinted background (`bg-background-tint-01`). Use for emphasis or nested cards that need visual separation.
## Border Variants
- **`none`** — No border. Use when cards are visually grouped or in tight layouts.
- **`dashed`** — Dashed border. Use for placeholder or empty states.
- **`solid`** — Solid border. Use for prominent, standalone cards.
## Usage
```tsx
import { Card } from "@opal/components";
// Default card (light background, no border, sm padding, md rounding)
// Default card (light background, no border, lg padding + rounding)
<Card>
<h2>Card Title</h2>
<p>Card content</p>
</Card>
// Large padding + rounding with solid border
<Card paddingVariant="lg" roundingVariant="lg" borderVariant="solid">
<p>Spacious card</p>
</Card>
// Compact card with solid border
<Card paddingVariant="xs" roundingVariant="sm" borderVariant="solid">
<Card borderVariant="solid" sizeVariant="sm">
<p>Compact card</p>
</Card>
@@ -60,4 +59,9 @@ import { Card } from "@opal/components";
<Card backgroundVariant="none" borderVariant="dashed">
<p>No items yet</p>
</Card>
// Heavy background, tight padding
<Card backgroundVariant="heavy" sizeVariant="xs">
<p>Highlighted content</p>
</Card>
```

View File

@@ -1,5 +1,6 @@
import "@opal/components/cards/card/styles.css";
import type { PaddingVariants, RoundingVariants } from "@opal/types";
import type { ContainerSizeVariants } from "@opal/types";
import { containerSizeVariants } from "@opal/shared";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
@@ -11,34 +12,21 @@ type BorderVariant = "none" | "dashed" | "solid";
type CardProps = {
/**
* Padding preset.
* Size preset — controls padding and border-radius.
*
* | Value | Class |
* |---------|---------|
* | `"lg"` | `p-6` |
* | `"md"` | `p-4` |
* | `"sm"` | `p-2` |
* | `"xs"` | `p-1` |
* | `"2xs"` | `p-0.5` |
* | `"fit"` | `p-0` |
* Padding comes from the shared size scale. Rounding follows the same
* mapping as `Button` / `Interactive.Container`:
*
* @default "sm"
* | Size | Rounding |
* |--------|------------|
* | `lg` | `default` |
* | `md``sm` | `compact` |
* | `xs``2xs` | `mini` |
* | `fit` | `default` |
*
* @default "lg"
*/
paddingVariant?: PaddingVariants;
/**
* Border-radius preset.
*
* | Value | Class |
* |--------|--------------|
* | `"xs"` | `rounded-04` |
* | `"sm"` | `rounded-08` |
* | `"md"` | `rounded-12` |
* | `"lg"` | `rounded-16` |
*
* @default "md"
*/
roundingVariant?: RoundingVariants;
sizeVariant?: ContainerSizeVariants;
/**
* Background fill intensity.
@@ -67,23 +55,17 @@ type CardProps = {
};
// ---------------------------------------------------------------------------
// Mappings
// Rounding
// ---------------------------------------------------------------------------
const paddingForVariant: Record<PaddingVariants, string> = {
lg: "p-6",
md: "p-4",
sm: "p-2",
xs: "p-1",
"2xs": "p-0.5",
fit: "p-0",
};
const roundingForVariant: Record<RoundingVariants, string> = {
lg: "rounded-16",
md: "rounded-12",
/** Maps a size variant to a rounding class, mirroring the Button pattern. */
const roundingForSize: Record<ContainerSizeVariants, string> = {
lg: "rounded-12",
md: "rounded-08",
sm: "rounded-08",
xs: "rounded-04",
"2xs": "rounded-04",
fit: "rounded-12",
};
// ---------------------------------------------------------------------------
@@ -91,15 +73,14 @@ const roundingForVariant: Record<RoundingVariants, string> = {
// ---------------------------------------------------------------------------
function Card({
paddingVariant = "sm",
roundingVariant = "md",
sizeVariant = "lg",
backgroundVariant = "light",
borderVariant = "none",
ref,
children,
}: CardProps) {
const padding = paddingForVariant[paddingVariant];
const rounding = roundingForVariant[roundingVariant];
const { padding } = containerSizeVariants[sizeVariant];
const rounding = roundingForSize[sizeVariant];
return (
<div

View File

@@ -6,12 +6,12 @@ A pre-configured Card for empty states. Renders a transparent card with a dashed
## Props
| Prop | Type | Default | Description |
| ----------------- | --------------------------- | ---------- | ------------------------------------------------ |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset for the card |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| Prop | Type | Default | Description |
| ------------- | -------------------------- | ---------- | ------------------------------------------------ |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `sizeVariant` | `SizeVariant` | `"lg"` | Size preset controlling padding and rounding |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
## Usage
@@ -25,6 +25,6 @@ import { SvgSparkle, SvgFileText } from "@opal/icons";
// With custom icon
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
// With custom padding
<EmptyMessageCard paddingVariant="xs" icon={SvgFileText} title="No documents available." />
// With custom size
<EmptyMessageCard sizeVariant="sm" icon={SvgFileText} title="No documents available." />
```

View File

@@ -1,7 +1,8 @@
import { Card } from "@opal/components/cards/card/components";
import { Content } from "@opal/layouts";
import { SvgEmpty } from "@opal/icons";
import type { IconFunctionComponent, PaddingVariants } from "@opal/types";
import type { ContainerSizeVariants } from "@opal/types";
import type { IconFunctionComponent } from "@opal/types";
// ---------------------------------------------------------------------------
// Types
@@ -14,8 +15,8 @@ type EmptyMessageCardProps = {
/** Primary message text. */
title: string;
/** Padding preset for the card. */
paddingVariant?: PaddingVariants;
/** Size preset controlling padding and rounding of the card. */
sizeVariant?: ContainerSizeVariants;
/** Ref forwarded to the root Card div. */
ref?: React.Ref<HTMLDivElement>;
@@ -28,7 +29,7 @@ type EmptyMessageCardProps = {
function EmptyMessageCard({
icon = SvgEmpty,
title,
paddingVariant = "sm",
sizeVariant = "lg",
ref,
}: EmptyMessageCardProps) {
return (
@@ -36,7 +37,7 @@ function EmptyMessageCard({
ref={ref}
backgroundVariant="none"
borderVariant="dashed"
paddingVariant={paddingVariant}
sizeVariant={sizeVariant}
>
<Content
icon={icon}

View File

@@ -24,7 +24,6 @@ type TextFont =
| "secondary-body"
| "secondary-action"
| "secondary-mono"
| "secondary-mono-label"
| "figure-small-label"
| "figure-small-value"
| "figure-keystroke";
@@ -89,7 +88,6 @@ const FONT_CONFIG: Record<TextFont, string> = {
"secondary-body": "font-secondary-body",
"secondary-action": "font-secondary-action",
"secondary-mono": "font-secondary-mono",
"secondary-mono-label": "font-secondary-mono-label",
"figure-small-label": "font-figure-small-label",
"figure-small-value": "font-figure-small-value",
"figure-keystroke": "font-figure-keystroke",

View File

@@ -9,7 +9,7 @@ Structural container shared by both `Interactive.Stateless` and `Interactive.Sta
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `heightVariant` | `SizeVariant` | `"lg"` | Height preset (`2xs``lg`, `fit`) |
| `roundingVariant` | `"md" \| "sm" \| "xs"` | `"md"` | Border-radius preset |
| `roundingVariant` | `"default" \| "compact" \| "mini"` | `"default"` | Border-radius preset |
| `widthVariant` | `WidthVariant` | — | Width preset (`"auto"`, `"fit"`, `"full"`) |
| `border` | `boolean` | `false` | Renders a 1px border |
| `type` | `"submit" \| "button" \| "reset"` | — | When set, renders a `<button>` element |
@@ -18,7 +18,7 @@ Structural container shared by both `Interactive.Stateless` and `Interactive.Sta
```tsx
<Interactive.Stateless variant="default" prominence="primary">
<Interactive.Container heightVariant="sm" roundingVariant="sm" border>
<Interactive.Container heightVariant="sm" roundingVariant="compact" border>
<span>Content</span>
</Interactive.Container>
</Interactive.Stateless>

View File

@@ -3,7 +3,7 @@ import type { Route } from "next";
import "@opal/core/interactive/shared.css";
import React from "react";
import { cn } from "@opal/utils";
import type { ButtonType, RoundingVariants, WithoutStyles } from "@opal/types";
import type { ButtonType, WithoutStyles } from "@opal/types";
import {
containerSizeVariants,
type ContainerSizeVariants,
@@ -16,17 +16,19 @@ import { useDisabled } from "@opal/core/disabled/components";
// Types
// ---------------------------------------------------------------------------
type InteractiveContainerRoundingVariant = Extract<
RoundingVariants,
"md" | "sm" | "xs"
>;
const interactiveContainerRoundingVariants: Record<
InteractiveContainerRoundingVariant,
string
> = {
md: "rounded-12",
sm: "rounded-08",
xs: "rounded-04",
/**
* Border-radius presets for `Interactive.Container`.
*
* - `"default"` — Default radius of 0.75rem (12px), matching card rounding
* - `"compact"` — Smaller radius of 0.5rem (8px), for tighter/inline elements
* - `"mini"` — Smallest radius of 0.25rem (4px)
*/
type InteractiveContainerRoundingVariant =
keyof typeof interactiveContainerRoundingVariants;
const interactiveContainerRoundingVariants = {
default: "rounded-12",
compact: "rounded-08",
mini: "rounded-04",
} as const;
/**
@@ -97,7 +99,7 @@ function InteractiveContainer({
ref,
type,
border,
roundingVariant = "md",
roundingVariant = "default",
heightVariant = "lg",
widthVariant = "fit",
...props

View File

@@ -37,35 +37,6 @@ export type SizeVariants = "fit" | "full" | "lg" | "md" | "sm" | "xs" | "2xs";
*/
export type ContainerSizeVariants = Exclude<SizeVariants, "full">;
/**
* Padding size variants.
*
* | Variant | Class |
* |---------|---------|
* | `lg` | `p-6` |
* | `md` | `p-4` |
* | `sm` | `p-2` |
* | `xs` | `p-1` |
* | `2xs` | `p-0.5` |
* | `fit` | `p-0` |
*/
export type PaddingVariants = Extract<
SizeVariants,
"fit" | "lg" | "md" | "sm" | "xs" | "2xs"
>;
/**
* Rounding size variants.
*
* | Variant | Class |
* |---------|--------------|
* | `lg` | `rounded-16` |
* | `md` | `rounded-12` |
* | `sm` | `rounded-08` |
* | `xs` | `rounded-04` |
*/
export type RoundingVariants = Extract<SizeVariants, "lg" | "md" | "sm" | "xs">;
/**
* Extreme size variants ("fit" and "full" only).
*

View File

@@ -8,7 +8,6 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
release: process.env.SENTRY_RELEASE,
// Only capture unhandled exceptions
tracesSampleRate: 0,
debug: false,

View File

@@ -7,7 +7,6 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
release: process.env.SENTRY_RELEASE,
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,

View File

@@ -0,0 +1,153 @@
import { Form, Formik } from "formik";
import { toast } from "@/hooks/useToast";
import { createApiKey, updateApiKey } from "./lib";
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
import { FormField } from "@/refresh-components/form/FormField";
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
import { APIKey } from "./types";
import { SvgKey } from "@opal/icons";
export interface OnyxApiKeyFormProps {
onClose: () => void;
onCreateApiKey: (apiKey: APIKey) => void;
apiKey?: APIKey;
}
export default function OnyxApiKeyForm({
onClose,
onCreateApiKey,
apiKey,
}: OnyxApiKeyFormProps) {
const isUpdate = apiKey !== undefined;
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="sm" height="lg">
<Modal.Header
icon={SvgKey}
title={isUpdate ? "Update API Key" : "Create a new API Key"}
onClose={onClose}
/>
<Formik
initialValues={{
name: apiKey?.api_key_name || "",
role: apiKey?.api_key_role || UserRole.BASIC.toString(),
}}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
// Prepare the payload with the UserRole
const payload = {
...values,
role: values.role as UserRole, // Assign the role directly as a UserRole type
};
let response;
if (isUpdate) {
response = await updateApiKey(apiKey.api_key_id, payload);
} else {
response = await createApiKey(payload);
}
formikHelpers.setSubmitting(false);
if (response.ok) {
toast.success(
isUpdate
? "Successfully updated API key!"
: "Successfully created API key!"
);
if (!isUpdate) {
onCreateApiKey(await response.json());
}
onClose();
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
toast.error(
isUpdate
? `Error updating API key - ${errorMsg}`
: `Error creating API key - ${errorMsg}`
);
}
}}
>
{({ isSubmitting }) => (
<Form className="w-full overflow-visible">
<Modal.Body>
<Text as="p">
Choose a memorable name for your API key. This is optional and
can be added or changed later!
</Text>
<FormikField<string>
name="name"
render={(field, helper, _meta, state) => (
<FormField name="name" state={state} className="w-full">
<FormField.Label>Name (optional):</FormField.Label>
<FormField.Control>
<InputTypeIn
{...field}
placeholder=""
onClear={() => helper.setValue("")}
showClearButton={false}
/>
</FormField.Control>
</FormField>
)}
/>
<FormikField<string>
name="role"
render={(field, helper, _meta, state) => (
<FormField name="role" state={state} className="w-full">
<FormField.Label>Role:</FormField.Label>
<FormField.Control>
<InputSelect
value={field.value}
onValueChange={(value) => helper.setValue(value)}
>
<InputSelect.Trigger placeholder="Select a role" />
<InputSelect.Content>
<InputSelect.Item
value={UserRole.LIMITED.toString()}
>
{USER_ROLE_LABELS[UserRole.LIMITED]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.BASIC.toString()}>
{USER_ROLE_LABELS[UserRole.BASIC]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.ADMIN.toString()}>
{USER_ROLE_LABELS[UserRole.ADMIN]}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</FormField.Control>
<FormField.Description>
Select the role for this API key. Limited has access to
simple public APIs. Basic has access to regular user
APIs. Admin has access to admin level APIs.
</FormField.Description>
</FormField>
)}
/>
</Modal.Body>
<Modal.Footer>
<Disabled disabled={isSubmitting}>
<Button type="submit">
{isUpdate ? "Update" : "Create"}
</Button>
</Disabled>
</Modal.Footer>
</Form>
)}
</Formik>
</Modal.Content>
</Modal>
);
}

Some files were not shown because too many files have changed in this diff Show More