Compare commits

..

11 Commits

Author SHA1 Message Date
Nik
bab95d8bf0 fix(chat): remove duplicate drain_done declaration after rebase 2026-03-31 20:02:29 -07:00
Nik
eb7bc74e1b fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old code called executor.shutdown(wait=False) with
no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly

Unit test updated to reflect the async self-completion semantics: the test
blocks the worker inside run_llm_loop until gen.close() sets drain_done,
then waits for completion_called inside the patch context (while mocks are
still active) to avoid calling the real get_session_with_current_tenant.
2026-03-31 20:02:29 -07:00
Nik
29da0aefb5 feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
Adds support for running 2-3 LLMs in parallel within a single chat turn,
with responses streamed interleaved to the frontend via the merged queue
infrastructure introduced in the preceding PR.

Backend changes
- process_message.py: restore llm_overrides param on build_chat_turn and
  _stream_chat_turn; restore is_multi branching for LLM setup, context
  window sizing, and message ID reservation; add _build_model_display_name
  and handle_multi_model_stream (public multi-model entrypoint)
- db/chat.py: add reserve_multi_model_message_ids (reserves N assistant
  message placeholders sharing the same parent), set_preferred_response
  (marks one response as the user's preferred), and extend
  translate_db_message_to_chat_message_detail with preferred_response_id
  and model_display_name fields
- chat_backend.py: route requests with llm_overrides >1 through
  handle_multi_model_stream; reject non-streaming multi-model requests with
  OnyxError; add /set-preferred-response endpoint

Tests
- test_multi_model_streaming.py: unit tests for _run_models drain loop
  (arrival-order yield, error isolation, cancellation), handle_multi_model_stream
  validation guards, and N=1 backwards-compatibility
2026-03-31 20:01:21 -07:00
Nik
6c86301c51 fix(chat): remove bounded queue and packet drops — match old behavior
Old code used queue.Queue() (unbounded, blocking put). New code introduced
queue.Queue(maxsize=100) + put(timeout=3.0) + silent drop on queue.Full —
a regression in all three callsites:

- Emitter.emit(): data packets silently dropped on queue full
- _run_model exception path: model errors silently lost
- _run_model finally (_MODEL_DONE): if dropped, drain loop hangs forever
  (models_remaining never reaches 0)

Fix: remove maxsize, remove all timeout= arguments, remove all
except queue.Full handlers. The drain_done early-return in emit() is the
correct disconnect mechanism; queue backpressure is not needed.

Also adds _completion_done: bool type annotation and fixes the queue drain
comment (no longer unblocking timed-out puts — just releasing memory).
2026-03-31 20:00:46 -07:00
Nik
631146f48f fix(chat): use model_succeeded instead of check_is_connected on self-completion
On HTTP disconnect, check_is_connected() returns False, causing
llm_loop_completion_handle to treat a completed response as
user-cancelled and append "Generation was stopped by the user."
Use lambda: model_succeeded[model_idx] (always True here) instead,
matching the cancellation path's functools.partial(bool, model_succeeded[i]).
2026-03-31 18:42:04 -07:00
Nik
f327278506 fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old else branch just called executor.shutdown(wait=False)
with no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly
2026-03-31 18:14:50 -07:00
Nik
c7cc439862 fix(emitter): address Greptile P1/P2/P3 and Queue typing
- P1: executor.shutdown(wait=False) on early exit — don't block the
  server thread waiting for LLM workers; they will hit queue.Full
  timeouts and exit on their own (matches old run_chat_loop behavior)
- P2: wrap db_session.commit() in try/finally in build_chat_turn —
  reset processing status before propagating if commit fails, so the
  chat session isn't stuck at "processing" permanently
- P3: fix inaccurate comment "All worker threads have exited" — workers
  may still be closing their own DB sessions at that point; clarify
  that only the main-thread db_session is safe to use
- Queue[Any] → Queue[tuple[int, Packet | Exception | object]] in Emitter
2026-03-31 17:02:46 -07:00
Nik
3365a369e2 fix(review): address Greptile comments
- Add owner to bare TODO comment
- Restore placement field assertions weakened by Emitter refactor
2026-03-31 12:49:09 -07:00
Nik
470bda3fb5 refactor(chat): elegance pass on PR1 changed files
process_message.py:
- Fix `skip_clarification` field in ChatTurnSetup: inline comment inside
  the type annotation → separate `#` comment on the line above the field
- Flatten `model_tools` via list comprehension instead of manual extend loop
- `forced_tool_id` membership test: list → set comprehension (O(1) lookup)
- Trim `_run_model` inner-function docstring — private closure doesn't need
  10-line Args block
- Remove redundant inline param comments from `_stream_chat_turn` and
  `handle_stream_message_objects` where the docstring Args section already
  documents them
- Strip duplicate Args/Returns from `handle_stream_message_objects` docstring
  — it delegates entirely to `_stream_chat_turn`

emitter.py:
- Widen `merged_queue` annotation to `Queue[Any]`: Queue is invariant so
  `Queue[tuple[int, Packet]]` can't be passed a `Queue[tuple[int, Packet |
  Exception | object]]`; the emitter is a write-only producer and doesn't
  care what else lives on the queue
2026-03-31 12:16:38 -07:00
Nik
13f511e209 refactor(emitter): clean up string annotation and use model_copy
- Fix `"Queue"` forward-reference annotation → `Queue[tuple[int, Packet]]`
  (Queue is already imported, the string was unnecessary)
- Replace manual Placement field copy with `base.model_copy(update={...})`
- Remove redundant `key` variable (was just `self._model_idx`)
- Tighten docstring
2026-03-31 11:44:28 -07:00
Nik
c5e8ba1eab refactor(chat): replace bus-polling emitter with merged-queue streaming; fix 429 hang
Switch Emitter from a per-model event bus + polling thread to a single
bounded queue shared across all models.  Each emit() call puts directly onto
the queue; the drain loop in _run_models yields packets in arrival order.

Key changes
- emitter.py: remove Bus, get_default_emitter(); add Emitter(merged_queue, model_idx)
- chat_state.py: remove run_chat_loop_with_state_containers (113-line bus-poll loop)
- process_message.py: add ChatTurnSetup dataclass and build_chat_turn(); rewrite
  _stream_chat_turn + _run_models around the merged queue; single-model (N=1)
  path is fully backwards-compatible
- placement.py, override_models.py: add docstrings; LLMOverride gains display_name
- research_agent.py, custom_tool.py: update Emitter call sites
- test_emitter.py: new unit tests for queue routing, model_index tagging, placement

Frontend 429 fix
- lib.tsx: parse response body for human-readable detail on non-2xx responses
  instead of "HTTP error! status: 429"
- useChatController.ts: surface stack.error after the FIFO drain loop exits so
  the catch block replaces the thinking placeholder with an error message
2026-03-30 22:18:48 -07:00
241 changed files with 5291 additions and 11373 deletions

View File

@@ -704,9 +704,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -789,9 +786,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"

View File

@@ -1,108 +0,0 @@
"""backfill_account_type
Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None
_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"
# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
# Reflect the table structure for use in DML
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("email", sa.String),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
)
def upgrade() -> None:
# ------------------------------------------------------------------
# Step 1: Backfill account_type from role.
# Order matters — most-specific matches first so the final catch-all
# only touches rows that haven't been classified yet.
# ------------------------------------------------------------------
# 1a. API key virtual users → SERVICE_ACCOUNT
op.execute(
sa.update(user_table)
.where(
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
user_table.c.account_type.is_(None),
)
.values(account_type=_SERVICE_ACCOUNT)
)
# 1b. Anonymous user → ANONYMOUS
op.execute(
sa.update(user_table)
.where(
user_table.c.id == ANONYMOUS_USER_ID,
user_table.c.account_type.is_(None),
)
.values(account_type=_ANONYMOUS)
)
# 1c. SLACK_USER role → BOT
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "SLACK_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_BOT)
)
# 1d. EXT_PERM_USER role → EXT_PERM_USER
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "EXT_PERM_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_EXT_PERM_USER)
)
# 1e. Everything else → STANDARD
op.execute(
sa.update(user_table)
.where(user_table.c.account_type.is_(None))
.values(account_type=_STANDARD)
)
# ------------------------------------------------------------------
# Step 2: Set account_type to NOT NULL now that every row is filled.
# ------------------------------------------------------------------
op.alter_column(
"user",
"account_type",
nullable=False,
server_default="STANDARD",
)
def downgrade() -> None:
op.alter_column("user", "account_type", nullable=True, server_default=None)
op.execute(sa.update(user_table).values(account_type=None))

View File

@@ -1,104 +0,0 @@
"""add_effective_permissions
Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.
Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].
Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("effective_permissions", postgresql.JSONB),
)
user_user_group = sa.table(
"user__user_group",
sa.column("user_id", sa.Uuid),
sa.column("user_group_id", sa.Integer),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"effective_permissions",
postgresql.JSONB(),
nullable=False,
server_default=sa.text("'[]'::jsonb"),
),
)
conn = op.get_bind()
# Deduplicated permissions per user
deduped = (
sa.select(
user_user_group.c.user_id,
permission_grant.c.permission,
)
.select_from(
user_user_group.join(
permission_grant,
sa.and_(
permission_grant.c.group_id == user_user_group.c.user_group_id,
permission_grant.c.is_deleted == sa.false(),
),
)
)
.distinct()
.subquery("deduped")
)
# Aggregate into JSONB array per user (order is not guaranteed;
# consumers read this as a set so ordering does not matter)
perms_per_user = (
sa.select(
deduped.c.user_id,
sa.func.jsonb_agg(
deduped.c.permission,
type_=postgresql.JSONB,
).label("perms"),
)
.group_by(deduped.c.user_id)
.subquery("sub")
)
conn.execute(
user_table.update()
.where(user_table.c.id == perms_per_user.c.user_id)
.values(effective_permissions=perms_per_user.c.perms)
)
def downgrade() -> None:
op.drop_column("user", "effective_permissions")

View File

@@ -1,54 +0,0 @@
"""csv to tabular chat file type
Revision ID: 8188861f4e92
Revises: d8cdfee5df80
Create Date: 2026-03-31 19:23:05.753184
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8188861f4e92"
down_revision = "d8cdfee5df80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'csv'
THEN jsonb_set(elem, '{type}', '"tabular"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "csv"%'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'tabular'
THEN jsonb_set(elem, '{type}', '"csv"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "tabular"%'
"""
)

View File

@@ -1,136 +0,0 @@
"""seed_default_groups
Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None
# (group_name, permission_value)
DEFAULT_GROUPS = [
("Admin", "admin"),
("Basic", "basic"),
]
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_up_to_date", sa.Boolean),
sa.column("is_up_for_deletion", sa.Boolean),
sa.column("is_default", sa.Boolean),
)
permission_grant_table = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
candidate = f"{base} {CUSTOM_SUFFIX}"
attempt = 1
while attempt <= MAX_RENAME_ATTEMPTS:
exists = conn.execute(
sa.select(sa.literal(1))
.select_from(user_group_table)
.where(user_group_table.c.name == candidate)
.limit(1)
).fetchone()
if exists is None:
return candidate
attempt += 1
candidate = f"{base} (Custom {attempt})"
raise RuntimeError(
f"Could not find an available name for group '{base}' "
f"after {MAX_RENAME_ATTEMPTS} attempts"
)
def upgrade() -> None:
conn = op.get_bind()
for group_name, permission_value in DEFAULT_GROUPS:
# Step 1: Rename ALL existing groups that clash with the canonical name.
conflicting = conn.execute(
sa.select(user_group_table.c.id, user_group_table.c.name).where(
user_group_table.c.name == group_name
)
).fetchall()
for row_id, row_name in conflicting:
new_name = _find_available_name(conn, row_name)
op.execute(
sa.update(user_group_table)
.where(user_group_table.c.id == row_id)
.values(name=new_name, is_up_to_date=False)
)
# Step 2: Create a fresh default group.
result = conn.execute(
user_group_table.insert()
.values(
name=group_name,
is_up_to_date=True,
is_up_for_deletion=False,
is_default=True,
)
.returning(user_group_table.c.id)
).fetchone()
assert result is not None
group_id = result[0]
# Step 3: Upsert permission grant.
op.execute(
pg_insert(permission_grant_table)
.values(
group_id=group_id,
permission=permission_value,
grant_source="SYSTEM",
)
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
)
def downgrade() -> None:
# Remove the default groups created by this migration.
# First remove user-group memberships that reference default groups
# to avoid FK violations, then delete the groups themselves.
default_group_ids = sa.select(user_group_table.c.id).where(
user_group_table.c.is_default == True # noqa: E712
)
op.execute(
sa.delete(user__user_group_table).where(
user__user_group_table.c.user_group_id.in_(default_group_ids)
)
)
op.execute(
sa.delete(user_group_table).where(
user_group_table.c.is_default == True # noqa: E712
)
)

View File

@@ -1,84 +0,0 @@
"""grant_basic_to_existing_groups
Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.
Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_group = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("is_default", sa.Boolean),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
conn = op.get_bind()
already_has_basic = (
sa.select(sa.literal(1))
.select_from(permission_grant)
.where(
permission_grant.c.group_id == user_group.c.id,
permission_grant.c.permission == "basic",
)
.exists()
)
groups_needing_basic = sa.select(
user_group.c.id,
sa.literal("basic").label("permission"),
sa.literal("SYSTEM").label("grant_source"),
sa.literal(False).label("is_deleted"),
).where(
user_group.c.is_default == sa.false(),
~already_has_basic,
)
conn.execute(
permission_grant.insert().from_select(
["group_id", "permission", "grant_source", "is_deleted"],
groups_needing_basic,
)
)
def downgrade() -> None:
conn = op.get_bind()
non_default_group_ids = sa.select(user_group.c.id).where(
user_group.c.is_default == sa.false()
)
conn.execute(
permission_grant.delete().where(
permission_grant.c.permission == "basic",
permission_grant.c.grant_source == "SYSTEM",
permission_grant.c.group_id.in_(non_default_group_ids),
)
)

View File

@@ -1,116 +0,0 @@
"""assign_users_to_default_groups
Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_default", sa.Boolean),
)
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
sa.column("is_active", sa.Boolean),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def upgrade() -> None:
conn = op.get_bind()
# Look up default group IDs
admin_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Admin",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
basic_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Basic",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
if admin_row is None:
raise RuntimeError(
"Default 'Admin' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
if basic_row is None:
raise RuntimeError(
"Default 'Basic' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
# Users with role=admin → Admin group
# Exclude inactive placeholder/anonymous users that are not real users
admin_users = sa.select(
sa.literal(admin_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.role == "ADMIN",
user_table.c.is_active == True, # noqa: E712
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], admin_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
# Exclude inactive placeholder/anonymous users that are not real users
basic_users = sa.select(
sa.literal(basic_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.is_active == True, # noqa: E712
sa.or_(
sa.and_(
user_table.c.account_type == "STANDARD",
user_table.c.role != "ADMIN",
),
sa.and_(
user_table.c.account_type == "SERVICE_ACCOUNT",
user_table.c.role == "BASIC",
),
),
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], basic_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
def downgrade() -> None:
# Group memberships are left in place — removing them risks
# deleting memberships that existed before this migration.
pass

View File

@@ -1,55 +0,0 @@
"""add skipped to userfilestatus
Revision ID: d8cdfee5df80
Revises: 1d78c0ca7853
Create Date: 2026-04-01 10:47:12.593950
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d8cdfee5df80"
down_revision = "1d78c0ca7853"
branch_labels = None
depends_on = None
TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = (
"PROCESSING",
"INDEXING",
"COMPLETED",
"SKIPPED",
"FAILED",
"CANCELED",
"DELETING",
)
def _drop_status_check_constraint() -> None:
inspector = sa.inspect(op.get_bind())
for constraint in inspector.get_check_constraints(TABLE):
if COLUMN in constraint.get("sqltext", ""):
constraint_name = constraint["name"]
if constraint_name is not None:
op.drop_constraint(constraint_name, TABLE, type_="check")
def upgrade() -> None:
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
def downgrade() -> None:
op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

View File

@@ -5,7 +5,6 @@ from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.hooks",
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",

View File

@@ -55,15 +55,6 @@ ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,

View File

@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
@shared_task(
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
"""
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
idempotent, tenants already at head are a fast no-op. This ensures pool
tenants are always current so that signup doesn't hit schema mismatches
(e.g. missing columns added after the tenant was pre-provisioned).
"""
with get_session_with_shared_schema() as db_session:
pool_tenants = db_session.query(AvailableTenant).all()
tenant_ids = [t.tenant_id for t in pool_tenants]
if not tenant_ids:
return
task_logger.info(
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
)
for tenant_id in tenant_ids:
try:
run_alembic_migrations(tenant_id)
new_version = get_current_alembic_version(tenant_id)
with get_session_with_shared_schema() as db_session:
tenant = (
db_session.query(AvailableTenant)
.filter_by(tenant_id=tenant_id)
.first()
)
if tenant and tenant.alembic_version != new_version:
task_logger.info(
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
)
tenant.alembic_version = new_version
db_session.commit()
except Exception:
task_logger.exception(
f"Failed to migrate pool tenant {tenant_id}, skipping"
)
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -69,7 +69,5 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
"/admin/token-rate-limits",
# Evals
"/evals",
# Hook extensions
"/admin/hooks",
}
)

View File

@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
@@ -259,7 +255,6 @@ def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -274,7 +269,6 @@ def fetch_user_groups(
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
include_default: If False, excludes system default groups (is_default=True).
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -282,8 +276,6 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
db_session.add(db_user_group)
db_session.flush() # give the group an ID
# Every group gets the "basic" permission by default
db_session.add(
PermissionGrant(
group_id=db_user_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
)
db_session.flush()
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
@@ -510,9 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
cc_pair_ids=user_group.cc_pair_ids,
)
for uid in user_group.user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -820,9 +796,6 @@ def update_user_group(
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
for uid in set(added_user_ids) | set(removed_user_ids):
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -862,17 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
# Collect affected user IDs before cleanup deletes the relationships
affected_user_ids = (
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == user_group_id
)
)
.scalars()
.all()
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -901,11 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session=db_session, user_group_id=user_group_id
)
# Recompute permissions for affected users now that their
# membership in this group has been removed
for uid in affected_user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()

View File

@@ -1,385 +0,0 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if MULTI_TENANT:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def _execute_hook_impl(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -15,7 +15,6 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.features.hooks.api import router as hook_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
@@ -139,7 +138,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
include_router_with_global_prefix_prepended(application, evals_router)
include_router_with_global_prefix_prepended(application, hook_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -52,13 +52,11 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -488,7 +486,6 @@ def create_user(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
account_type=AccountType.STANDARD,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
@@ -509,25 +506,13 @@ def create_user(
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
# Assign user to default group BEFORE commit so everything is atomic.
# If this fails, the entire user creation rolls back and IdP can retry.
try:
assign_user_to_default_groups__no_commit(db_session, user)
except Exception:
dal.rollback()
logger.exception(f"Failed to assign SCIM user {email} to default groups")
return _scim_error_response(
500, f"Failed to assign user {email} to default group"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,

View File

@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
@router.get("/admin/user-group")
def list_user_groups(
include_default: bool = False,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
eager_load_for_snapshot=True,
include_default=include_default,
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
@@ -60,50 +56,27 @@ def list_user_groups(
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
include_default=include_default,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
include_default: bool = False,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
include_default=include_default,
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
include_default=include_default,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
user_group_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[str]:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
return [
grant.permission.value
for grant in group.permission_grants
if not grant.is_deleted
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
group = fetch_user_group(db_session, rename_request.id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
try:
return UserGroup.from_model(
rename_user_group(
@@ -215,9 +185,6 @@ def delete_user_group(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
group = fetch_user_group(db_session, user_group_id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:

View File

@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
is_default=user_group_model.is_default,
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
is_default=user_group_model.is_default,
)

View File

@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -1,110 +0,0 @@
"""
Permission resolution for group-based authorization.
Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""
from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any
from fastapi import Depends
from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
Permission.MANAGE_AGENTS.value: {
Permission.ADD_AGENTS.value,
Permission.READ_AGENTS.value,
},
Permission.MANAGE_DOCUMENT_SETS.value: {
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_CONNECTORS.value,
},
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
Permission.MANAGE_CONNECTORS.value: {
Permission.ADD_CONNECTORS.value,
Permission.READ_CONNECTORS.value,
},
Permission.MANAGE_USER_GROUPS.value: {
Permission.READ_CONNECTORS.value,
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_AGENTS.value,
Permission.READ_USERS.value,
},
}
def resolve_effective_permissions(granted: set[str]) -> set[str]:
"""Expand granted permissions with their implied permissions.
If "admin" is present, returns all 19 permissions.
"""
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
return set(ALL_PERMISSIONS)
effective = set(granted)
changed = True
while changed:
changed = False
for perm in list(effective):
implied = IMPLIED_PERMISSIONS.get(perm)
if implied and not implied.issubset(effective):
effective |= implied
changed = True
return effective
def get_effective_permissions(user: User) -> set[Permission]:
"""Read granted permissions from the column and expand implied permissions."""
granted: set[Permission] = set()
for p in user.effective_permissions:
try:
granted.add(Permission(p))
except ValueError:
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
return set(Permission)
expanded = resolve_effective_permissions({p.value for p in granted})
return {Permission(p) for p in expanded}
def require_permission(
required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
"""FastAPI dependency factory for permission-based access control.
Usage:
@router.get("/endpoint")
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
...
"""
async def dependency(user: User = Depends(current_user)) -> User:
effective = get_effective_permissions(user)
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
return user
if required not in effective:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
return user
return dependency

View File

@@ -5,8 +5,6 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override
from onyx.db.enums import AccountType
class UserRole(str, Enum):
"""
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
account_type: AccountType = AccountType.STANDARD
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
@@ -53,16 +50,12 @@ class UserCreate(schemas.BaseUserCreate):
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
# Force STANDARD for self-registration; only trusted paths
# (SCIM, API key creation) supply a different account_type directly.
d["account_type"] = AccountType.STANDARD
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
d.setdefault("account_type", self.account_type)
return d

View File

@@ -120,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
@@ -696,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
"account_type": AccountType.STANDARD,
}
user = await self.user_db.create(user_dict)
@@ -746,23 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
# Upgrade the user and assign default groups in a single
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
sync_user.account_type = AccountType.STANDARD
if was_inactive:
sync_user.is_active = True
assign_user_to_default_groups__no_commit(sync_db, sync_user)
sync_db.commit()
# Refresh the async user object so downstream code
# (e.g. oidc_expiry check) sees the updated fields.
user = await self.user_db.get(user.id) # type: ignore[arg-type]
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -848,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
event=MilestoneRecordType.TENANT_CREATED,
)
# Assign user to the appropriate default group (Admin or Basic).
# Must happen inside the try block while tenant context is active,
# otherwise get_session_with_current_tenant() targets the wrong schema.
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
with get_session_with_current_tenant() as db_session:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=is_admin
)
db_session.commit()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1576,7 +1554,6 @@ def get_anonymous_user() -> User:
is_verified=True,
is_superuser=False,
role=UserRole.LIMITED,
account_type=AccountType.ANONYMOUS,
use_memories=False,
enable_memory_tool=False,
)

View File

@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -66,7 +65,6 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -517,8 +515,7 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
sender: Any, **kwargs: Any # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -9,7 +9,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -138,7 +137,6 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
)
current_db_time = get_db_current_time(db_session)
total_batches: int | str = (
coordination_status.total_batches
if coordination_status.total_batches is not None
else "?"
)
if coordination_status.found:
task_logger.info(
f"Indexing attempt progress: "
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"completed_batches={coordination_status.completed_batches} "
f"total_batches={total_batches} "
f"total_batches={coordination_status.total_batches or '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_failures={coordination_status.total_failures}"
f"elapsed={(current_db_time - attempt.time_created).seconds}"
@@ -415,7 +410,7 @@ def check_indexing_completion(
logger.info(
f"Indexing status: "
f"indexing_completed={indexing_completed} "
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
f"batches_processed={batches_processed}/{batches_total or '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_chunks={coordination_status.total_chunks} "
f"total_failures={coordination_status.total_failures}"

View File

@@ -1,19 +1,8 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -159,114 +148,3 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,19 +1,40 @@
import threading
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
def __init__(self, bus: Queue):
self.bus = bus
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
drain_done: Optional event set by ``_run_models`` when the drain loop
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
immediately so worker threads can exit fast.
"""
def __init__(
self,
merged_queue: Queue[tuple[int, Packet | Exception | object]],
model_idx: int = 0,
drain_done: threading.Event | None = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
self._drain_done = drain_done
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
if self._drain_done is not None and self._drain_done.is_set():
return
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
self._merged_queue.put((self._model_idx, tagged))

File diff suppressed because it is too large Load Diff

View File

@@ -1079,6 +1079,7 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

View File

@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -278,7 +277,6 @@ class NotificationType(str, Enum):
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
USER_GROUP_ASSIGNMENT_FAILED = "user_group_assignment_failed"
class BlobType(str, Enum):
@@ -674,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -1,32 +0,0 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
canvas_client: CanvasApiClient,
course_id: int,
) -> ExternalAccess | None:
if not global_version.is_ee_version():
return None
ee_get_course_permissions = cast(
Callable[[CanvasApiClient, int], ExternalAccess | None],
fetch_versioned_implementation(
"onyx.external_permissions.canvas.access",
"get_course_permissions",
),
)
return ee_get_course_permissions(canvas_client, course_id)

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
@@ -191,22 +190,3 @@ class CanvasApiClient:
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,82 +1,17 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
"""Map Canvas API errors to connector framework exceptions."""
if e.status_code == 401:
raise CredentialExpiredError(
"Canvas API token is invalid or expired (HTTP 401)."
)
elif e.status_code == 403:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
)
else:
raise ConnectorValidationError(
f"Canvas API error (status={e.status_code}): {e}"
)
class CanvasCourse(BaseModel):
id: int
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
name: str
course_code: str
created_at: str
workflow_state: str
class CanvasPage(BaseModel):
@@ -84,22 +19,10 @@ class CanvasPage(BaseModel):
url: str
title: str
body: str | None = None
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
@@ -107,23 +30,10 @@ class CanvasAssignment(BaseModel):
description: str | None = None
html_url: str
course_id: int
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
@@ -133,17 +43,6 @@ class CanvasAnnouncement(BaseModel):
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
@@ -173,286 +72,3 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
try:
self.canvas_client.get("courses", params={"per_page": "1"})
logger.info("Canvas connector settings validated successfully")
except OnyxError as e:
_handle_canvas_api_error(e)
except ConnectorMissingCredentialError:
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
# TODO(benwu408): implemented in PR4 (perm sync)
raise NotImplementedError

View File

@@ -11,13 +11,11 @@ from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.errors import LoginFailure
from discord.flags import Intents
from discord.message import Message as DiscordMessage
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -211,19 +209,8 @@ def _manage_async_retrieval(
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents) as discord_client:
start_task = asyncio.create_task(discord_client.start(token))
ready_task = asyncio.create_task(discord_client.wait_until_ready())
done, _ = await asyncio.wait(
{start_task, ready_task},
return_when=asyncio.FIRST_COMPLETED,
)
# start() runs indefinitely once connected, so it only lands
# in `done` when login/connection failed — propagate the error.
if start_task in done:
ready_task.cancel()
start_task.result()
asyncio.create_task(discord_client.start(token))
await discord_client.wait_until_ready()
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=discord_client,
@@ -289,19 +276,6 @@ class DiscordConnector(PollConnector, LoadConnector):
self._discord_bot_token = credentials["discord_bot_token"]
return None
def validate_connector_settings(self) -> None:
loop = asyncio.new_event_loop()
try:
client = Client(intents=Intents.default())
try:
loop.run_until_complete(client.login(self.discord_bot_token))
except LoginFailure as e:
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
finally:
loop.run_until_complete(client.close())
finally:
loop.close()
def _manage_doc_batching(
self,
start: datetime | None = None,

View File

@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -11,19 +11,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -92,7 +87,6 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)
@@ -105,21 +99,7 @@ def insert_api_key(
)
db_session.add(api_key_row)
# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# LIMITED role service accounts should have no group membership.
# Late import to avoid circular dependency (api_key <- users <- api_key).
if api_key_args.role != UserRole.LIMITED:
from onyx.db.users import assign_user_to_default_groups__no_commit
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
result = db_session.execute(stmt)
chat_sessions = list(result.scalars().all())
chat_sessions = result.scalars().all()
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
return list(chat_sessions)
def delete_orphaned_search_docs(db_session: Session) -> None:
@@ -631,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -853,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -13,19 +13,19 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"
USER = "user"
SCIM = "scim"
SYSTEM = "system"
class IndexingStatus(str, PyEnum):
@@ -215,7 +215,6 @@ class UserFileStatus(str, PyEnum):
PROCESSING = "PROCESSING"
INDEXING = "INDEXING"
COMPLETED = "COMPLETED"
SKIPPED = "SKIPPED"
FAILED = "FAILED"
CANCELED = "CANCELED"
DELETING = "DELETING"

View File

@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)
effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
Enum(Permission, native_enum=False), nullable=False
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False

View File

@@ -3,7 +3,6 @@ from datetime import timezone
from uuid import UUID
from sqlalchemy import cast
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
@@ -91,18 +90,9 @@ def get_notifications(
notif_type: NotificationType | None = None,
include_dismissed: bool = True,
) -> list[Notification]:
if user is None:
user_filter = Notification.user_id.is_(None)
elif user.role == UserRole.ADMIN:
# Admins see their own notifications AND admin-targeted ones (user_id IS NULL)
user_filter = or_(
Notification.user_id == user.id,
Notification.user_id.is_(None),
)
else:
user_filter = Notification.user_id == user.id
query = select(Notification).where(user_filter)
query = select(Notification).where(
Notification.user_id == user.id if user else Notification.user_id.is_(None)
)
if not include_dismissed:
query = query.where(Notification.dismissed.is_(False))
if notif_type:

View File

@@ -1,97 +0,0 @@
"""
DB operations for recomputing user effective_permissions.
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""
from collections import defaultdict
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup
def recompute_user_permissions__no_commit(user_id: UUID, db_session: Session) -> None:
"""Recompute a single user's granted permissions from their group grants.
Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().
Does NOT commit — caller must commit the session.
"""
stmt = (
select(PermissionGrant.permission)
.join(
User__UserGroup,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id == user_id,
PermissionGrant.is_deleted.is_(False),
)
)
rows = db_session.execute(stmt).scalars().all()
# sorted for consistent ordering in DB — easier to read when debugging
granted = sorted({p.value for p in rows})
db_session.execute(
update(User).where(User.id == user_id).values(effective_permissions=granted)
)
def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.
Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = list(
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
)
if not user_ids:
return
# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(user_ids),
PermissionGrant.is_deleted.is_(False),
)
).all()
# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID, set[str]] = defaultdict(set)
for uid in user_ids:
perms_by_user[uid] # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)
for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid)
.values(effective_permissions=sorted(perms))
)

View File

@@ -7,7 +7,6 @@ from fastapi import HTTPException
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
@@ -18,7 +17,6 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import UserFileStatus
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
@@ -36,19 +34,9 @@ class CategorizedFilesResult(BaseModel):
user_files: list[UserFile]
rejected_files: list[RejectedFile]
id_to_temp_id: dict[str, str]
# Filenames that should be stored but not indexed.
skip_indexing_filenames: set[str] = Field(default_factory=set)
# Allow SQLAlchemy ORM models inside this result container
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def indexable_files(self) -> list[UserFile]:
return [
uf
for uf in self.user_files
if (uf.name or "") not in self.skip_indexing_filenames
]
def build_hashed_file_key(file: UploadFile) -> str:
name_prefix = (file.filename or "")[:50]
@@ -82,7 +70,6 @@ def create_user_files(
)
if new_temp_id is not None:
id_to_temp_id[str(new_id)] = new_temp_id
should_skip = (file.filename or "") in categorized_files.skip_indexing
new_file = UserFile(
id=new_id,
user_id=user.id,
@@ -94,7 +81,6 @@ def create_user_files(
link_url=link_url,
content_type=file.content_type,
file_type=file.content_type,
status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
)
# Persist the UserFile first to satisfy FK constraints for association table
@@ -112,7 +98,6 @@ def create_user_files(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files.skip_indexing,
)
@@ -138,7 +123,6 @@ def upload_files_to_user_files_with_indexing(
user_files = categorized_files_result.user_files
rejected_files = categorized_files_result.rejected_files
id_to_temp_id = categorized_files_result.id_to_temp_id
indexable_files = categorized_files_result.indexable_files
# Trigger per-file processing immediately for the current tenant
tenant_id = get_current_tenant_id()
for rejected_file in rejected_files:
@@ -150,12 +134,12 @@ def upload_files_to_user_files_with_indexing(
from onyx.background.task_utils import drain_processing_loop
background_tasks.add_task(drain_processing_loop, tenant_id)
for user_file in indexable_files:
for user_file in user_files:
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
for user_file in indexable_files:
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
@@ -171,7 +155,6 @@ def upload_files_to_user_files_with_indexing(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
)

View File

@@ -19,7 +19,6 @@ from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.enums import AccountType
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -28,11 +27,8 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
@@ -302,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)
@@ -313,7 +308,6 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user
@@ -350,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
@@ -382,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users
def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.
Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.
Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return
target_group_name = "Admin" if is_admin else "Basic"
default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)
if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)
# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return
savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return
from onyx.db.permissions import recompute_user_permissions__no_commit
recompute_user_permissions__no_commit(user.id, db_session)
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
@@ -503,14 +421,13 @@ def delete_user_from_db(
def batch_get_user_groups(
db_session: Session,
user_ids: list[UUID],
include_default: bool = False,
) -> dict[UUID, list[tuple[int, str]]]:
"""Fetch group memberships for a batch of users in a single query.
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
if not user_ids:
return {}
stmt = (
rows = db_session.execute(
select(
User__UserGroup.user_id,
UserGroup.id,
@@ -518,11 +435,7 @@ def batch_get_user_groups(
)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.where(User__UserGroup.user_id.in_(user_ids))
)
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
rows = db_session.execute(stmt).all()
).all()
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
for user_id, group_id, group_name in rows:

View File

@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
) -> list[str]:
"""Searches the index and returns only document chunk IDs.

View File

@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
RANDOM = "random"
DOC_ID_RETRIEVAL = "doc_id_retrieval"
ID_RETRIEVAL = "id_retrieval"
DOCUMENT_IDS = "document_ids"
UNKNOWN = "unknown"

View File

@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -15,7 +15,6 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
class OnyxMimeTypes:
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
CSV_MIME_TYPES = {"text/csv"}
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
TEXT_MIME_TYPES = {
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
@@ -35,12 +34,13 @@ class OnyxMimeTypes:
PDF_MIME_TYPE,
WORD_PROCESSING_MIME_TYPE,
PRESENTATION_MIME_TYPE,
SPREADSHEET_MIME_TYPE,
"message/rfc822",
"application/epub+zip",
}
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
)
EXCLUDED_IMAGE_TYPES = {
@@ -53,11 +53,6 @@ class OnyxMimeTypes:
class OnyxFileExtensions:
TABULAR_EXTENSIONS = {
".csv",
".tsv",
".xlsx",
}
PLAIN_TEXT_EXTENSIONS = {
".txt",
".md",

View File

@@ -13,14 +13,13 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
# Tabular data files (CSV, XLSX)
TABULAR = "tabular"
CSV = "csv"
def is_text_file(self) -> bool:
return self in (
ChatFileType.PLAIN_TEXT,
ChatFileType.DOC,
ChatFileType.TABULAR,
ChatFileType.CSV,
)

View File

@@ -1,3 +1,4 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
@@ -6,7 +7,10 @@ from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted EE deployments.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Use as: Depends(require_hook_enabled)
"""
@@ -15,3 +19,8 @@ def require_hook_enabled() -> None:
OnyxErrorCode.SINGLE_TENANT_ONLY,
"Hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.ENV_VAR_GATED,
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -1,22 +1,79 @@
"""CE hook executor.
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
HookSkipped and HookSoftFailed are real classes kept here because
process_message.py (CE code) uses isinstance checks against them.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
execute_hook is the public entry point. It dispatches to _execute_hook_impl
via fetch_versioned_implementation so that:
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
@@ -30,15 +87,277 @@ class HookSoftFailed:
T = TypeVar("T", bound=BaseModel)
def _execute_hook_impl(
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
db_session: Session, # noqa: ARG001
hook_point: HookPoint, # noqa: ARG001
payload: dict[str, Any], # noqa: ARG001
response_type: type[T], # noqa: ARG001
) -> T | HookSkipped | HookSoftFailed:
"""CE no-op — hooks are not available without EE."""
return HookSkipped()
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def execute_hook(
@@ -48,15 +367,25 @@ def execute_hook(
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point.
"""Execute the hook for the given hook point synchronously.
Dispatches to the versioned implementation so EE gets the real executor
and CE gets the no-op stub, without any changes at the call site.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
return impl(
db_session=db_session,
hook_point=hook_point,
payload=payload,
response_type=response_type,
)
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -438,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -454,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -76,18 +76,11 @@ class CategorizedFiles(BaseModel):
acceptable: list[UploadFile] = Field(default_factory=list)
rejected: list[RejectedFile] = Field(default_factory=list)
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
# Filenames within `acceptable` that should be stored but not indexed.
skip_indexing: set[str] = Field(default_factory=set)
# Allow FastAPI UploadFile instances
model_config = ConfigDict(arbitrary_types_allowed=True)
def _skip_token_threshold(extension: str) -> bool:
"""Return True if this file extension should bypass the token limit."""
return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
if max(width, height) <= cap:
return width, height
@@ -271,17 +264,7 @@ def categorize_uploaded_files(
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
exceeds_threshold = (
token_threshold is not None and token_count > token_threshold
)
if exceeds_threshold and _skip_token_threshold(extension):
# Exempt extensions (e.g. spreadsheets) are accepted
# but flagged to skip indexing — only metadata is
# injected into the LLM context.
results.acceptable.append(upload)
results.acceptable_file_to_token_count[filename] = token_count
results.skip_indexing.add(filename)
elif exceeds_threshold:
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,

View File

@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.schemas import UserRole
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
@@ -774,13 +773,6 @@ def _get_token_created_at(
return get_current_token_creation_postgres(user, db_session)
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
def get_current_user_permissions(
user: User = Depends(current_user),
) -> list[str]:
return sorted(p.value for p in get_effective_permissions(user))
@router.get("/me", tags=PUBLIC_API_TAGS)
def verify_user_logged_in(
request: Request,

View File

@@ -7,7 +7,6 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.models import User
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
id: UUID
email: str
role: UserRole
account_type: AccountType
is_active: bool
password_configured: bool
personal_name: str | None
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
id=user.id,
email=user.email,
role=user.role,
account_type=user.account_type,
is_active=user.is_active,
password_configured=user.password_configured,
personal_name=user.personal_name,

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
return ChatFileType.IMAGE
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
return ChatFileType.TABULAR
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
return ChatFileType.CSV
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
return ChatFileType.DOC

View File

@@ -2,11 +2,25 @@ from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to. ``0`` for single-model
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
for pre-LLM setup packets (e.g. message ID info) that are yielded
before any Emitter runs.
"""
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -21,6 +21,7 @@ from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -37,7 +38,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -98,7 +98,7 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=not MULTI_TENANT,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(

View File

@@ -116,7 +116,7 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant EE deployments only.
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -1,3 +1,4 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -708,7 +709,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter_queue.qsize()}")

View File

@@ -1,5 +1,6 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -11,7 +12,6 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use default emitter if none provided
# Use a discard emitter if none provided (packets go nowhere)
if emitter is None:
emitter = get_default_emitter()
emitter = Emitter(merged_queue=queue.Queue())
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
dynamic_schema_info=None,
)

View File

@@ -1,4 +1,3 @@
import io
import json
from typing import Any
from typing import cast
@@ -10,7 +9,6 @@ from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_chat_file_by_id
@@ -171,13 +169,10 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
chat_file = self._load_file(file_id)
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
# DOC type in a loaded file means plaintext extraction failed and the
# content is the original binary (e.g. raw PDF/DOCX bytes).
if chat_file.file_type not in (
ChatFileType.PLAIN_TEXT,
ChatFileType.TABULAR,
):
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
raise ToolCallException(
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
llm_facing_message=(
@@ -186,19 +181,7 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
)
try:
if chat_file.file_type == ChatFileType.PLAIN_TEXT:
full_text = chat_file.content.decode("utf-8", errors="replace")
else:
full_text = (
extract_file_text(
file=io.BytesIO(chat_file.content),
file_name=chat_file.filename or "",
break_on_unprocessable=False,
)
or ""
)
except ToolCallException:
raise
full_text = chat_file.content.decode("utf-8", errors="replace")
except Exception:
raise ToolCallException(
message=f"Failed to decode file {file_id}",

View File

@@ -5,7 +5,6 @@ import asyncio
import json
import logging
import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
QUESTION_TIMEOUT_SECONDS = 300
QUESTION_RETRY_PAUSE_SECONDS = 30
MAX_QUESTION_ATTEMPTS = 3
@dataclass(frozen=True)
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
return f"{normalized}/api"
def load_completed_question_ids(output_file: Path) -> set[str]:
if not output_file.exists():
return set()
completed_ids: set[str] = set()
with output_file.open("r", encoding="utf-8") as file:
for line in file:
stripped = line.strip()
if not stripped:
continue
try:
record = json.loads(stripped)
except json.JSONDecodeError:
continue
question_id = record.get("question_id")
if isinstance(question_id, str) and question_id:
completed_ids.add(question_id)
return completed_ids
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
@@ -373,7 +348,6 @@ async def generate_answers(
api_base: str,
api_key: str,
parallelism: int,
skipped: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
@@ -408,178 +382,58 @@ async def generate_answers(
write_lock = asyncio.Lock()
completed = 0
successful = 0
stuck_count = 0
failed_questions: list[FailedQuestionRecord] = []
remaining_count = len(questions)
overall_total = remaining_count + skipped
question_durations: list[float] = []
run_start_time = time.monotonic()
def print_progress() -> None:
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
elapsed = time.monotonic() - run_start_time
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
done = skipped + completed
bar_width = 30
filled = (
int(bar_width * done / overall_total)
if overall_total
else bar_width
)
bar = "" * filled + "" * (bar_width - filled)
pct = (done / overall_total * 100) if overall_total else 100.0
parts = (
f"\r{bar} {pct:5.1f}% "
f"[{done}/{overall_total}] "
f"avg {avg_time:.1f}s/q "
f"elapsed {elapsed:.0f}s "
f"ETA {eta:.0f}s "
f"(ok:{successful} fail:{len(failed_questions)}"
)
if stuck_count:
parts += f" stuck:{stuck_count}"
if skipped:
parts += f" skip:{skipped}"
parts += ")"
sys.stderr.write(parts)
sys.stderr.flush()
print_progress()
total = len(questions)
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
nonlocal stuck_count
last_error: Exception | None = None
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
q_start = time.monotonic()
try:
async with semaphore:
result = await asyncio.wait_for(
submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
),
timeout=QUESTION_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
async with progress_lock:
stuck_count += 1
logger.warning(
"Question %s timed out after %ss (attempt %s/%s, "
"total stuck: %s) — retrying in %ss",
question_record.question_id,
QUESTION_TIMEOUT_SECONDS,
attempt,
MAX_QUESTION_ATTEMPTS,
stuck_count,
QUESTION_RETRY_PAUSE_SECONDS,
)
print_progress()
last_error = TimeoutError(
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
)
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
continue
except Exception as exc:
duration = time.monotonic() - q_start
async with progress_lock:
completed += 1
question_durations.append(duration)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
remaining_count,
)
print_progress()
return
duration = time.monotonic() - q_start
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
except Exception as exc:
async with progress_lock:
completed += 1
successful += 1
question_durations.append(duration)
print_progress()
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
return
# All attempts exhausted due to timeouts
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(last_error),
)
)
logger.error(
"Question %s failed after %s timeout attempts (%s/%s)",
question_record.question_id,
MAX_QUESTION_ATTEMPTS,
completed,
remaining_count,
)
print_progress()
successful += 1
logger.info("Processed %s/%s questions", completed, total)
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
# Final newline after progress bar
sys.stderr.write("\n")
sys.stderr.flush()
total_elapsed = time.monotonic() - run_start_time
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
resume_suffix = (
f"{skipped} previously completed, "
f"{skipped + successful}/{overall_total} overall"
if skipped
else ""
)
logger.info(
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
successful,
remaining_count,
total_elapsed,
avg_time,
stuck_suffix,
resume_suffix,
)
if failed_questions:
logger.warning(
"%s questions failed:",
"Completed with %s failed questions and %s successful questions.",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
@@ -599,30 +453,7 @@ def main() -> None:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
completed_ids = load_completed_question_ids(args.output_file)
logger.info(
"Found %s already-answered question IDs in %s",
len(completed_ids),
args.output_file,
)
total_before_filter = len(questions)
questions = [q for q in questions if q.question_id not in completed_ids]
skipped = total_before_filter - len(questions)
if skipped:
logger.info(
"Resuming: %s/%s already answered, %s remaining",
skipped,
total_before_filter,
len(questions),
)
else:
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
if not questions:
logger.info("All questions already answered. Nothing to do.")
return
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
@@ -632,7 +463,6 @@ def main() -> None:
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
skipped=skipped,
)
)

View File

@@ -27,11 +27,13 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_store.file_store import get_default_file_store
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def create_test_user(
db_session: Session,
email_prefix: str,
role: UserRole = UserRole.BASIC,
account_type: AccountType = AccountType.STANDARD,
) -> User:
def create_test_user(db_session: Session, email_prefix: str) -> User:
"""Helper to create a test user with a unique email"""
# Use UUID to ensure unique email addresses
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
@@ -74,8 +68,7 @@ def create_test_user(
is_active=True,
is_superuser=False,
is_verified=True,
role=role,
account_type=account_type,
role=UserRole.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import AccountType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.models import UserRole
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_ext_perm_user(db_session: Session, name: str) -> User:
"""Create an external-permission user for group sync tests."""
return create_test_user(
db_session,
name,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
def _create_test_connector_credential_pair(
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
) -> ConnectorCredentialPair:
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
def test_initial_group_sync(self, db_session: Session) -> None:
"""Test syncing external groups for the first time (initial sync)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Mock external groups data as a generator that yields the expected groups
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
def test_update_existing_groups(self, db_session: Session) -> None:
"""Test updating existing groups (adding/removing users)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with original groups
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
def test_remove_groups(self, db_session: Session) -> None:
"""Test removing groups (groups that no longer exist in external system)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with multiple groups
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
def test_empty_group_sync(self, db_session: Session) -> None:
"""Test syncing when no groups are returned (all groups removed)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user1 = create_test_user(db_session, "user1")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with groups
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
# Create many test users
users = []
for i in range(150): # More than the batch size of 100
users.append(_create_ext_perm_user(db_session, f"user{i}"))
users.append(create_test_user(db_session, f"user{i}"))
cc_pair = _create_test_connector_credential_pair(db_session)
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
"""Test syncing a mix of regular and public groups"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
def mixed_group_sync_func(

View File

@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.enums import BuildSessionStatus
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -1,51 +0,0 @@
"""
Tests that account_type is correctly set when creating users through
the internal DB functions: add_slack_user_if_not_exists and
batch_add_ext_perm_user_if_not_exists.
These functions are called by background workers (Slack bot, permission sync)
and are not exposed via API endpoints, so they must be tested directly.
"""
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.models import UserRole
from onyx.db.users import add_slack_user_if_not_exists
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
users = batch_add_ext_perm_user_if_not_exists(
db_session, ["extperm_acct_type@test.com"]
)
assert len(users) == 1
user = users[0]
assert user.role == UserRole.EXT_PERM_USER
assert user.account_type == AccountType.EXT_PERM_USER
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
db_session: Session,
) -> None:
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
email = "ext_to_slack_acct_type@test.com"
# Create as ext_perm user first
batch_add_ext_perm_user_if_not_exists(db_session, [email])
# Now "upgrade" via slack path
user = add_slack_user_if_not_exists(db_session, email)
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT

View File

@@ -8,7 +8,6 @@ import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
is_superuser=True,
is_verified=True,
role=UserRole.ADMIN,
account_type=AccountType.STANDARD,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,6 +13,7 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -20,7 +21,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -16,7 +17,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)

View File

@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(
file_descriptor: FileDescriptor = {
"id": user_file.file_id,
"type": ChatFileType.TABULAR,
"type": ChatFileType.CSV,
"name": "data.csv",
"user_file_id": str(user_file.id),
}

View File

@@ -126,15 +126,6 @@ class UserManager:
return test_user
@staticmethod
def get_permissions(user: DATestUser) -> list[str]:
response = requests.get(
url=f"{API_SERVER_URL}/me/permissions",
headers=user.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def is_role(
user_to_verify: DATestUser,

View File

@@ -104,30 +104,13 @@ class UserGroupManager:
)
response.raise_for_status()
@staticmethod
def get_permissions(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
) -> list[str]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all(
user_performing_action: DATestUser,
include_default: bool = False,
) -> list[UserGroup]:
params: dict[str, str] = {}
if include_default:
params["include_default"] = "true"
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers,
params=params,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]

View File

@@ -1,13 +1,9 @@
from uuid import UUID
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
@@ -37,120 +33,3 @@ def test_limited(reset: None) -> None: # noqa: ARG001
headers=api_key.headers,
)
assert response.status_code == 403
def _get_service_account_account_type(
admin_user: DATestUser,
api_key_user_id: UUID,
) -> AccountType:
"""Fetch the account_type of a service account user via the user listing API."""
response = requests.get(
f"{API_SERVER_URL}/manage/users",
headers=admin_user.headers,
params={"include_api_keys": "true"},
)
response.raise_for_status()
data = response.json()
user_id_str = str(api_key_user_id)
for user in data["accepted"]:
if user["id"] == user_id_str:
return AccountType(user["account_type"])
raise AssertionError(
f"Service account user {user_id_str} not found in user listing"
)
def _get_default_group_user_ids(
admin_user: DATestUser,
) -> tuple[set[str], set[str]]:
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
admin_ids = {str(u.id) for u in admin_group.users}
basic_ids = {str(u.id) for u in basic_group.users}
return admin_ids, basic_ids
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.LIMITED,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify no group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert (
user_id_str not in admin_ids
), "LIMITED API key should NOT be in Admin default group"
assert (
user_id_str not in basic_ids
), "LIMITED API key should NOT be in Basic default group"
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.BASIC,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Basic group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
assert (
user_id_str not in admin_ids
), "BASIC API key should NOT be in Admin default group"
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.ADMIN,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Admin group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
assert (
user_id_str not in basic_ids
), "ADMIN API key should NOT be in Basic default group"

View File

@@ -4,10 +4,8 @@ import pytest
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@@ -97,63 +95,3 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_sets_account_type_and_group(
reset: None, # noqa: ARG001
) -> None:
"""
Test that SAML login sets account_type to STANDARD when converting a
non-web user (EXT_PERM_USER) and that the user receives the correct role
(BASIC) after conversion.
This validates the permissions-migration-phase2 changes which ensure that:
1. account_type is updated to 'standard' on SAML conversion
2. The converted user is assigned to the Basic default group
"""
# Create an admin user (first user is automatically admin)
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# Create a user and set them as EXT_PERM_USER
test_email = "ext_convert@example.com"
test_user = UserManager.create(email=test_email)
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_email},
headers=admin_user.headers,
)
response.raise_for_status()
user_data = response.json()
# Verify account_type is set to standard after conversion
assert (
user_data["account_type"] == AccountType.STANDARD.value
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
# Verify role is BASIC after conversion
assert user_data["role"] == UserRole.BASIC.value
# Verify the user was assigned to the Basic default group
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
assert basic_default, "Basic default group not found"
basic_group = basic_default[0]
member_emails = {u.email for u in basic_group.users}
assert test_email in member_emails, (
f"Converted user '{test_email}' not found in Basic default group members: "
f"{member_emails}"
)

View File

@@ -35,16 +35,9 @@ from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.enums import AccountType
from onyx.server.settings.models import ApplicationStatus
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
@@ -218,49 +211,6 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
_assert_entra_emails(body, email)
def test_create_user_default_group_and_account_type(
scim_token: str, idp_style: str
) -> None:
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
email = f"scim_defaults_{idp_style}@example.com"
ext_id = f"ext-defaults-{idp_style}"
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
assert resp.status_code == 201
user_id = resp.json()["id"]
# --- Verify group assignment via SCIM GET ---
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
assert get_resp.status_code == 200
groups = get_resp.json().get("groups", [])
group_names = {g["display"] for g in groups}
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
# --- Verify account_type via admin API ---
admin = UserManager.login_as_user(
DATestUser(
id="",
email=build_email(ADMIN_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
page = UserManager.get_user_page(
user_performing_action=admin,
search_query=email,
)
assert page.total_items >= 1
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
assert (
scim_user_snapshot is not None
), f"SCIM user {email} not found in user listing"
assert (
scim_user_snapshot.account_type == AccountType.STANDARD
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
def test_get_user(scim_token: str, idp_style: str) -> None:
"""GET /Users/{id} returns the user resource with all stored fields."""
email = f"scim_get_{idp_style}@example.com"

View File

@@ -1,9 +1,3 @@
import mimetypes
from typing import Any
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
@@ -85,90 +79,3 @@ def test_send_message_with_text_file_attachment(admin_user: DATestUser) -> None:
assert (
"third line" in response.full_message.lower()
), "Chat response should contain the contents of the file"
def _set_token_threshold(admin_user: DATestUser, threshold_k: int) -> None:
"""Set the file token count threshold via admin settings API."""
response = requests.put(
f"{API_SERVER_URL}/admin/settings",
json={"file_token_count_threshold_k": threshold_k},
headers=admin_user.headers,
)
response.raise_for_status()
def _upload_raw(
filename: str,
content: bytes,
user: DATestUser,
) -> dict[str, Any]:
"""Upload a file and return the full JSON response (user_files + rejected_files)."""
mime_type, _ = mimetypes.guess_type(filename)
headers = user.headers.copy()
headers.pop("Content-Type", None)
response = requests.post(
f"{API_SERVER_URL}/user/projects/file/upload",
files=[("files", (filename, content, mime_type or "application/octet-stream"))],
headers=headers,
)
response.raise_for_status()
return response.json()
def test_csv_over_token_threshold_uploaded_not_indexed(
admin_user: DATestUser,
) -> None:
"""CSV exceeding token threshold is uploaded (accepted) but skips indexing."""
_set_token_threshold(admin_user, threshold_k=1)
try:
# ~2000 tokens with default tokenizer, well over 1K threshold
content = ("x " * 100 + "\n") * 20
result = _upload_raw("large.csv", content.encode(), admin_user)
assert len(result["user_files"]) == 1, "CSV should be accepted"
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
assert (
result["user_files"][0]["status"] == "SKIPPED"
), "CSV over threshold should be SKIPPED (uploaded but not indexed)"
assert (
result["user_files"][0]["chunk_count"] is None
), "Skipped file should have no chunks"
finally:
_set_token_threshold(admin_user, threshold_k=200)
def test_csv_under_token_threshold_uploaded_and_indexed(
admin_user: DATestUser,
) -> None:
"""CSV under token threshold is uploaded and queued for indexing."""
_set_token_threshold(admin_user, threshold_k=200)
try:
content = "col1,col2\na,b\n"
result = _upload_raw("small.csv", content.encode(), admin_user)
assert len(result["user_files"]) == 1, "CSV should be accepted"
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
assert (
result["user_files"][0]["status"] == "PROCESSING"
), "CSV under threshold should be PROCESSING (queued for indexing)"
finally:
_set_token_threshold(admin_user, threshold_k=200)
def test_txt_over_token_threshold_rejected(
admin_user: DATestUser,
) -> None:
"""Non-exempt file exceeding token threshold is rejected entirely."""
_set_token_threshold(admin_user, threshold_k=1)
try:
# ~2000 tokens, well over 1K threshold. Unlike CSV, .txt is not
# exempt from the threshold so the file should be rejected.
content = ("x " * 100 + "\n") * 20
result = _upload_raw("big.txt", content.encode(), admin_user)
assert len(result["user_files"]) == 0, "File should not be accepted"
assert len(result["rejected_files"]) == 1, "File should be rejected"
assert "token limit" in result["rejected_files"][0]["reason"].lower()
finally:
_set_token_threshold(admin_user, threshold_k=200)

View File

@@ -1,118 +0,0 @@
import os
import pytest
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import UserGroup as UserGroupModel
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_user_gets_permissions_when_added_to_group(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
# basic_user starts with only "basic" from the default group
initial_permissions = UserManager.get_permissions(basic_user)
assert "basic" in initial_permissions
assert "add:agents" not in initial_permissions
# Create a new group and add basic_user
group = UserGroupManager.create(
name="perm-test-group",
user_ids=[admin_user.id, basic_user.id],
user_performing_action=admin_user,
)
# Grant a non-basic permission to the group and recompute
with get_session_with_current_tenant() as db_session:
db_group = db_session.get(UserGroupModel, group.id)
assert db_group is not None
db_session.add(
PermissionGrant(
group_id=db_group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
)
db_session.flush()
recompute_user_permissions__no_commit(basic_user.id, db_session)
db_session.commit()
# Verify the user gained the new permission (expanded includes read:agents)
updated_permissions = UserManager.get_permissions(basic_user)
assert (
"add:agents" in updated_permissions
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
assert (
"read:agents" in updated_permissions
), f"User should have implied 'read:agents', got: {updated_permissions}"
assert "basic" in updated_permissions
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_group_permission_change_propagates_to_all_members(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_propagate")
user_a: DATestUser = UserManager.create(name="user_a_propagate")
user_b: DATestUser = UserManager.create(name="user_b_propagate")
group = UserGroupManager.create(
name="propagate-test-group",
user_ids=[admin_user.id, user_a.id, user_b.id],
user_performing_action=admin_user,
)
# Neither user should have add:agents yet
for u in (user_a, user_b):
assert "add:agents" not in UserManager.get_permissions(u)
# Grant add:agents to the group, then batch-recompute
with get_session_with_current_tenant() as db_session:
grant = PermissionGrant(
group_id=group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
db_session.add(grant)
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
# Both users should now have the permission (plus implied read:agents)
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
# Soft-delete the grant and recompute — permission should be removed
with get_session_with_current_tenant() as db_session:
db_grant = (
db_session.query(PermissionGrant)
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
.first()
)
assert db_grant is not None
db_grant.is_deleted = True
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"

View File

@@ -1,30 +0,0 @@
import os
import pytest
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
user_group = UserGroupManager.create(
name="basic-perm-test-group",
user_ids=[admin_user.id],
user_performing_action=admin_user,
)
permissions = UserGroupManager.get_permissions(
user_group=user_group,
user_performing_action=admin_user,
)
assert (
"basic" in permissions
), f"New group should have 'basic' permission, got: {permissions}"

View File

@@ -1,78 +0,0 @@
"""Integration tests for default group assignment on user registration.
Verifies that:
- The first registered user is assigned to the Admin default group
- Subsequent registered users are assigned to the Basic default group
- account_type is set to STANDARD for email/password registrations
"""
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
# Register first user — should become admin
admin_user: DATestUser = UserManager.create(name="first_user")
assert admin_user.role == UserRole.ADMIN
# Register second user — should become basic
basic_user: DATestUser = UserManager.create(name="second_user")
assert basic_user.role == UserRole.BASIC
# Fetch all groups including default ones
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
# Find the default Admin and Basic groups
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
# Verify admin user is in Admin group and NOT in Basic group
admin_group_user_ids = {str(u.id) for u in admin_group.users}
basic_group_user_ids = {str(u.id) for u in basic_group.users}
assert (
admin_user.id in admin_group_user_ids
), "First user should be in Admin default group"
assert (
admin_user.id not in basic_group_user_ids
), "First user should NOT be in Basic default group"
# Verify basic user is in Basic group and NOT in Admin group
assert (
basic_user.id in basic_group_user_ids
), "Second user should be in Basic default group"
assert (
basic_user.id not in admin_group_user_ids
), "Second user should NOT be in Admin default group"
# Verify account_type is STANDARD for both users via user listing API
paginated_result = UserManager.get_user_page(
user_performing_action=admin_user,
page_num=0,
page_size=10,
)
users_by_id = {str(u.id): u for u in paginated_result.items}
admin_snapshot = users_by_id.get(admin_user.id)
basic_snapshot = users_by_id.get(basic_user.id)
assert admin_snapshot is not None, "Admin user not found in user listing"
assert basic_snapshot is not None, "Basic user not found in user listing"
assert (
admin_snapshot.account_type == AccountType.STANDARD
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
assert (
basic_snapshot.account_type == AccountType.STANDARD
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"

View File

@@ -1,176 +0,0 @@
"""
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
"""
from unittest.mock import MagicMock
import pytest
from onyx.auth.permissions import ALL_PERMISSIONS
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.permissions import require_permission
from onyx.auth.permissions import resolve_effective_permissions
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# resolve_effective_permissions
# ---------------------------------------------------------------------------
class TestResolveEffectivePermissions:
def test_empty_set(self) -> None:
assert resolve_effective_permissions(set()) == set()
def test_basic_no_implications(self) -> None:
result = resolve_effective_permissions({"basic"})
assert result == {"basic"}
def test_single_implication(self) -> None:
result = resolve_effective_permissions({"add:agents"})
assert result == {"add:agents", "read:agents"}
def test_manage_agents_implies_add_and_read(self) -> None:
"""manage:agents directly maps to {add:agents, read:agents}."""
result = resolve_effective_permissions({"manage:agents"})
assert result == {"manage:agents", "add:agents", "read:agents"}
def test_manage_connectors_chain(self) -> None:
result = resolve_effective_permissions({"manage:connectors"})
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
def test_manage_document_sets(self) -> None:
result = resolve_effective_permissions({"manage:document_sets"})
assert result == {
"manage:document_sets",
"read:document_sets",
"read:connectors",
}
def test_manage_user_groups_implies_all_reads(self) -> None:
result = resolve_effective_permissions({"manage:user_groups"})
assert result == {
"manage:user_groups",
"read:connectors",
"read:document_sets",
"read:agents",
"read:users",
}
def test_admin_override(self) -> None:
result = resolve_effective_permissions({"admin"})
assert result == set(ALL_PERMISSIONS)
def test_admin_with_others(self) -> None:
result = resolve_effective_permissions({"admin", "basic"})
assert result == set(ALL_PERMISSIONS)
def test_multi_group_union(self) -> None:
result = resolve_effective_permissions(
{"add:agents", "manage:connectors", "basic"}
)
assert result == {
"basic",
"add:agents",
"read:agents",
"manage:connectors",
"add:connectors",
"read:connectors",
}
def test_toggle_permission_no_implications(self) -> None:
result = resolve_effective_permissions({"read:agent_analytics"})
assert result == {"read:agent_analytics"}
def test_all_permissions_for_admin(self) -> None:
result = resolve_effective_permissions({"admin"})
assert len(result) == len(ALL_PERMISSIONS)
# ---------------------------------------------------------------------------
# get_effective_permissions (expands implied at read time)
# ---------------------------------------------------------------------------
class TestGetEffectivePermissions:
def test_expands_implied_permissions(self) -> None:
"""Column stores only granted; get_effective_permissions expands implied."""
user = MagicMock()
user.effective_permissions = ["add:agents"]
result = get_effective_permissions(user)
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
def test_admin_expands_to_all(self) -> None:
user = MagicMock()
user.effective_permissions = ["admin"]
result = get_effective_permissions(user)
assert result == set(Permission)
def test_basic_stays_basic(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
result = get_effective_permissions(user)
assert result == {Permission.BASIC_ACCESS}
def test_empty_column(self) -> None:
user = MagicMock()
user.effective_permissions = []
result = get_effective_permissions(user)
assert result == set()
# ---------------------------------------------------------------------------
# require_permission (FastAPI dependency)
# ---------------------------------------------------------------------------
class TestRequirePermission:
@pytest.mark.asyncio
async def test_admin_bypass(self) -> None:
"""Admin stored in column should pass any permission check."""
user = MagicMock()
user.effective_permissions = ["admin"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_has_required_permission(self) -> None:
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_implied_permission_passes(self) -> None:
"""manage:connectors implies read:connectors at read time."""
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.READ_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_missing_permission_raises(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
with pytest.raises(OnyxError) as exc_info:
await dep(user=user)
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
@pytest.mark.asyncio
async def test_empty_permissions_fails(self) -> None:
user = MagicMock()
user.effective_permissions = []
dep = require_permission(Permission.BASIC_ACCESS)
with pytest.raises(OnyxError):
await dep(user=user)

View File

@@ -1,29 +0,0 @@
"""
Unit tests for UserCreate schema dict methods.
Verifies that account_type is always included in create_update_dict
and create_update_dict_superuser.
"""
from onyx.auth.schemas import UserCreate
from onyx.db.enums import AccountType
def test_create_update_dict_includes_default_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_includes_explicit_account_type() -> None:
uc = UserCreate(
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
)
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_superuser_includes_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict_superuser()
assert d["account_type"] == AccountType.STANDARD

View File

@@ -0,0 +1,173 @@
"""Unit tests for the Emitter class.
All tests use the streaming mode (merged_queue required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
"""Return (emitter, queue) wired together."""
mq: queue.Queue = queue.Queue()
return Emitter(merged_queue=mq, model_idx=model_idx), mq
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
def test_emit_lands_on_merged_queue(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet())
assert not mq.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
def test_multiple_packets_delivered_fifo(self) -> None:
emitter, mq = _make_emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
_, t1 = mq.get_nowait()
_, t2 = mq.get_nowait()
assert t1.placement.turn_index == 0
assert t2.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
def test_key_equals_model_idx(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_n1_key_is_zero(self) -> None:
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
def test_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
emitter, mq = _make_emitter()
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
emitter, mq = _make_emitter()
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)

Some files were not shown because too many files have changed in this diff Show More