Compare commits

..

13 Commits

Author SHA1 Message Date
Nik
058f3d1403 fix(chat): restore use_metadata_only() handling in extract_context_files
Restores the metadata-only file path that was regressed when
process_message.py was copied from a branch that pre-dated this logic.

Changes:
- Restore mime_type_to_chat_file_type import and build_file_context import
- Restore use_metadata_only() filter in aggregate_tokens calculation
- Restore tool_metadata accumulation in the "files fit" loop
- Restore file_metadata_for_tool in the "files fit" return path
- Restore _build_tool_metadata helper (single-file version)
2026-04-01 09:40:48 -07:00
Nik
6255a299b1 feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
- _run_models: ThreadPoolExecutor drain loop, arrival-order queue, GeneratorExit
- handle_multi_model_stream: validation, setup, per-model Emitter(model_idx=i)
- run_multi_model_stream: public entrypoint for N=2-3 model comparison
- set_preferred_response: DB helper for user to pick preferred model response
- reserve_multi_model_message_ids: reserve N ChatMessage rows atomically
- Fixes B1 (self-completion race), B2 (stop-button errored models), P1 (orphaned rows)
- 26 unit tests covering validation, drain loop, emitter routing, DB helpers
2026-04-01 08:38:30 -07:00
Nik
caa8811d61 style(chat): apply black formatting to process_message.py 2026-04-01 08:36:21 -07:00
Nik
7208d7ba8d fix(chat): remove bounded queue and packet drops — match old behavior
Old code used queue.Queue() (unbounded, blocking put). New code introduced
queue.Queue(maxsize=100) + put(timeout=3.0) + silent drop on queue.Full —
a regression in all three callsites:

- Emitter.emit(): data packets silently dropped on queue full
- _run_model exception path: model errors silently lost
- _run_model finally (_MODEL_DONE): if dropped, drain loop hangs forever
  (models_remaining never reaches 0)

Fix: remove maxsize, remove all timeout= arguments, remove all
except queue.Full handlers. The drain_done early-return in emit() is the
correct disconnect mechanism; queue backpressure is not needed.

Also adds _completion_done: bool type annotation and fixes the queue drain
comment (no longer unblocking timed-out puts — just releasing memory).
2026-04-01 08:25:53 -07:00
Nik
23058c416d fix(chat): use model_succeeded instead of check_is_connected on self-completion
On HTTP disconnect, check_is_connected() returns False, causing
llm_loop_completion_handle to treat a completed response as
user-cancelled and append "Generation was stopped by the user."
Use lambda: model_succeeded[model_idx] (always True here) instead,
matching the cancellation path's functools.partial(bool, model_succeeded[i]).
2026-04-01 08:25:04 -07:00
Nik
0874e0a5e6 fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old else branch just called executor.shutdown(wait=False)
with no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly
2026-04-01 08:25:04 -07:00
Nik
165237faf4 fix(emitter): address Greptile P1/P2/P3 and Queue typing
- P1: executor.shutdown(wait=False) on early exit — don't block the
  server thread waiting for LLM workers; they will hit queue.Full
  timeouts and exit on their own (matches old run_chat_loop behavior)
- P2: wrap db_session.commit() in try/finally in build_chat_turn —
  reset processing status before propagating if commit fails, so the
  chat session isn't stuck at "processing" permanently
- P3: fix inaccurate comment "All worker threads have exited" — workers
  may still be closing their own DB sessions at that point; clarify
  that only the main-thread db_session is safe to use
- Queue[Any] → Queue[tuple[int, Packet | Exception | object]] in Emitter
2026-04-01 08:25:04 -07:00
Nik
19c3122fec fix(review): address Greptile comments
- Add owner to bare TODO comment
- Restore placement field assertions weakened by Emitter refactor
2026-04-01 08:23:42 -07:00
Nik
00b228b357 refactor(chat): elegance pass on PR1 changed files
process_message.py:
- Fix `skip_clarification` field in ChatTurnSetup: inline comment inside
  the type annotation → separate `#` comment on the line above the field
- Flatten `model_tools` via list comprehension instead of manual extend loop
- `forced_tool_id` membership test: list → set comprehension (O(1) lookup)
- Trim `_run_model` inner-function docstring — private closure doesn't need
  10-line Args block
- Remove redundant inline param comments from `_stream_chat_turn` and
  `handle_stream_message_objects` where the docstring Args section already
  documents them
- Strip duplicate Args/Returns from `handle_stream_message_objects` docstring
  — it delegates entirely to `_stream_chat_turn`

emitter.py:
- Widen `merged_queue` annotation to `Queue[Any]`: Queue is invariant so
  `Queue[tuple[int, Packet]]` can't be passed a `Queue[tuple[int, Packet |
  Exception | object]]`; the emitter is a write-only producer and doesn't
  care what else lives on the queue
2026-04-01 08:23:42 -07:00
Nik
59ae32f764 refactor(emitter): clean up string annotation and use model_copy
- Fix `"Queue"` forward-reference annotation → `Queue[tuple[int, Packet]]`
  (Queue is already imported, the string was unnecessary)
- Replace manual Placement field copy with `base.model_copy(update={...})`
- Remove redundant `key` variable (was just `self._model_idx`)
- Tighten docstring
2026-04-01 08:22:51 -07:00
Nik
76000330ad refactor(chat): replace bus-polling emitter with merged-queue streaming; fix 429 hang
Switch Emitter from a per-model event bus + polling thread to a single
bounded queue shared across all models.  Each emit() call puts directly onto
the queue; the drain loop in _run_models yields packets in arrival order.

Key changes
- emitter.py: remove Bus, get_default_emitter(); add Emitter(merged_queue, model_idx)
- chat_state.py: remove run_chat_loop_with_state_containers (113-line bus-poll loop)
- process_message.py: add ChatTurnSetup dataclass and build_chat_turn(); rewrite
  _stream_chat_turn + _run_models around the merged queue; single-model (N=1)
  path is fully backwards-compatible
- placement.py, override_models.py: add docstrings; LLMOverride gains display_name
- research_agent.py, custom_tool.py: update Emitter call sites
- test_emitter.py: new unit tests for queue routing, model_index tagging, placement

Frontend 429 fix
- lib.tsx: parse response body for human-readable detail on non-2xx responses
  instead of "HTTP error! status: 429"
- useChatController.ts: surface stack.error after the FIFO drain loop exits so
  the catch block replaces the thinking placeholder with an error message
2026-04-01 08:22:51 -07:00
Raunak Bhagat
eb6bd42c1e refactor(admin): revamp Service Accounts page and AdminListHeader (#9824) 2026-04-01 15:11:01 +00:00
Danelegend
953cc28625 feat(files): Inject file metadata over content for certain files (#9786) 2026-04-01 13:19:11 +00:00
83 changed files with 3575 additions and 3364 deletions

View File

@@ -1,108 +0,0 @@
"""backfill_account_type
Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None
_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"
# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
# Reflect the table structure for use in DML
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("email", sa.String),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
)
def upgrade() -> None:
# ------------------------------------------------------------------
# Step 1: Backfill account_type from role.
# Order matters — most-specific matches first so the final catch-all
# only touches rows that haven't been classified yet.
# ------------------------------------------------------------------
# 1a. API key virtual users → SERVICE_ACCOUNT
op.execute(
sa.update(user_table)
.where(
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
user_table.c.account_type.is_(None),
)
.values(account_type=_SERVICE_ACCOUNT)
)
# 1b. Anonymous user → ANONYMOUS
op.execute(
sa.update(user_table)
.where(
user_table.c.id == ANONYMOUS_USER_ID,
user_table.c.account_type.is_(None),
)
.values(account_type=_ANONYMOUS)
)
# 1c. SLACK_USER role → BOT
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "SLACK_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_BOT)
)
# 1d. EXT_PERM_USER role → EXT_PERM_USER
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "EXT_PERM_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_EXT_PERM_USER)
)
# 1e. Everything else → STANDARD
op.execute(
sa.update(user_table)
.where(user_table.c.account_type.is_(None))
.values(account_type=_STANDARD)
)
# ------------------------------------------------------------------
# Step 2: Set account_type to NOT NULL now that every row is filled.
# ------------------------------------------------------------------
op.alter_column(
"user",
"account_type",
nullable=False,
server_default="STANDARD",
)
def downgrade() -> None:
op.alter_column("user", "account_type", nullable=True, server_default=None)
op.execute(sa.update(user_table).values(account_type=None))

View File

@@ -1,104 +0,0 @@
"""add_effective_permissions
Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.
Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].
Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("effective_permissions", postgresql.JSONB),
)
user_user_group = sa.table(
"user__user_group",
sa.column("user_id", sa.Uuid),
sa.column("user_group_id", sa.Integer),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"effective_permissions",
postgresql.JSONB(),
nullable=False,
server_default=sa.text("'[]'::jsonb"),
),
)
conn = op.get_bind()
# Deduplicated permissions per user
deduped = (
sa.select(
user_user_group.c.user_id,
permission_grant.c.permission,
)
.select_from(
user_user_group.join(
permission_grant,
sa.and_(
permission_grant.c.group_id == user_user_group.c.user_group_id,
permission_grant.c.is_deleted == sa.false(),
),
)
)
.distinct()
.subquery("deduped")
)
# Aggregate into JSONB array per user (order is not guaranteed;
# consumers read this as a set so ordering does not matter)
perms_per_user = (
sa.select(
deduped.c.user_id,
sa.func.jsonb_agg(
deduped.c.permission,
type_=postgresql.JSONB,
).label("perms"),
)
.group_by(deduped.c.user_id)
.subquery("sub")
)
conn.execute(
user_table.update()
.where(user_table.c.id == perms_per_user.c.user_id)
.values(effective_permissions=perms_per_user.c.perms)
)
def downgrade() -> None:
op.drop_column("user", "effective_permissions")

View File

@@ -1,136 +0,0 @@
"""seed_default_groups
Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None
# (group_name, permission_value)
DEFAULT_GROUPS = [
("Admin", "admin"),
("Basic", "basic"),
]
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_up_to_date", sa.Boolean),
sa.column("is_up_for_deletion", sa.Boolean),
sa.column("is_default", sa.Boolean),
)
permission_grant_table = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
candidate = f"{base} {CUSTOM_SUFFIX}"
attempt = 1
while attempt <= MAX_RENAME_ATTEMPTS:
exists = conn.execute(
sa.select(sa.literal(1))
.select_from(user_group_table)
.where(user_group_table.c.name == candidate)
.limit(1)
).fetchone()
if exists is None:
return candidate
attempt += 1
candidate = f"{base} (Custom {attempt})"
raise RuntimeError(
f"Could not find an available name for group '{base}' "
f"after {MAX_RENAME_ATTEMPTS} attempts"
)
def upgrade() -> None:
conn = op.get_bind()
for group_name, permission_value in DEFAULT_GROUPS:
# Step 1: Rename ALL existing groups that clash with the canonical name.
conflicting = conn.execute(
sa.select(user_group_table.c.id, user_group_table.c.name).where(
user_group_table.c.name == group_name
)
).fetchall()
for row_id, row_name in conflicting:
new_name = _find_available_name(conn, row_name)
op.execute(
sa.update(user_group_table)
.where(user_group_table.c.id == row_id)
.values(name=new_name, is_up_to_date=False)
)
# Step 2: Create a fresh default group.
result = conn.execute(
user_group_table.insert()
.values(
name=group_name,
is_up_to_date=True,
is_up_for_deletion=False,
is_default=True,
)
.returning(user_group_table.c.id)
).fetchone()
assert result is not None
group_id = result[0]
# Step 3: Upsert permission grant.
op.execute(
pg_insert(permission_grant_table)
.values(
group_id=group_id,
permission=permission_value,
grant_source="SYSTEM",
)
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
)
def downgrade() -> None:
# Remove the default groups created by this migration.
# First remove user-group memberships that reference default groups
# to avoid FK violations, then delete the groups themselves.
default_group_ids = sa.select(user_group_table.c.id).where(
user_group_table.c.is_default == True # noqa: E712
)
op.execute(
sa.delete(user__user_group_table).where(
user__user_group_table.c.user_group_id.in_(default_group_ids)
)
)
op.execute(
sa.delete(user_group_table).where(
user_group_table.c.is_default == True # noqa: E712
)
)

View File

@@ -1,84 +0,0 @@
"""grant_basic_to_existing_groups
Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.
Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_group = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("is_default", sa.Boolean),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
conn = op.get_bind()
already_has_basic = (
sa.select(sa.literal(1))
.select_from(permission_grant)
.where(
permission_grant.c.group_id == user_group.c.id,
permission_grant.c.permission == "basic",
)
.exists()
)
groups_needing_basic = sa.select(
user_group.c.id,
sa.literal("basic").label("permission"),
sa.literal("SYSTEM").label("grant_source"),
sa.literal(False).label("is_deleted"),
).where(
user_group.c.is_default == sa.false(),
~already_has_basic,
)
conn.execute(
permission_grant.insert().from_select(
["group_id", "permission", "grant_source", "is_deleted"],
groups_needing_basic,
)
)
def downgrade() -> None:
conn = op.get_bind()
non_default_group_ids = sa.select(user_group.c.id).where(
user_group.c.is_default == sa.false()
)
conn.execute(
permission_grant.delete().where(
permission_grant.c.permission == "basic",
permission_grant.c.grant_source == "SYSTEM",
permission_grant.c.group_id.in_(non_default_group_ids),
)
)

View File

@@ -1,116 +0,0 @@
"""assign_users_to_default_groups
Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_default", sa.Boolean),
)
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
sa.column("is_active", sa.Boolean),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def upgrade() -> None:
conn = op.get_bind()
# Look up default group IDs
admin_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Admin",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
basic_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Basic",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
if admin_row is None:
raise RuntimeError(
"Default 'Admin' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
if basic_row is None:
raise RuntimeError(
"Default 'Basic' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
# Users with role=admin → Admin group
# Exclude inactive placeholder/anonymous users that are not real users
admin_users = sa.select(
sa.literal(admin_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.role == "ADMIN",
user_table.c.is_active == True, # noqa: E712
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], admin_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
# Exclude inactive placeholder/anonymous users that are not real users
basic_users = sa.select(
sa.literal(basic_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.is_active == True, # noqa: E712
sa.or_(
sa.and_(
user_table.c.account_type == "STANDARD",
user_table.c.role != "ADMIN",
),
sa.and_(
user_table.c.account_type == "SERVICE_ACCOUNT",
user_table.c.role == "BASIC",
),
),
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], basic_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
def downgrade() -> None:
# Group memberships are left in place — removing them risks
# deleting memberships that existed before this migration.
pass

View File

@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
@@ -259,7 +255,6 @@ def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -274,7 +269,6 @@ def fetch_user_groups(
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
include_default: If False, excludes system default groups (is_default=True).
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -282,8 +276,6 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
db_session.add(db_user_group)
db_session.flush() # give the group an ID
# Every group gets the "basic" permission by default
db_session.add(
PermissionGrant(
group_id=db_user_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
)
db_session.flush()
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
@@ -510,9 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
cc_pair_ids=user_group.cc_pair_ids,
)
for uid in user_group.user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -820,9 +796,6 @@ def update_user_group(
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
for uid in set(added_user_ids) | set(removed_user_ids):
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -862,17 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
# Collect affected user IDs before cleanup deletes the relationships
affected_user_ids = (
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == user_group_id
)
)
.scalars()
.all()
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -901,11 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session=db_session, user_group_id=user_group_id
)
# Recompute permissions for affected users now that their
# membership in this group has been removed
for uid in affected_user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()

View File

@@ -52,13 +52,11 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -488,7 +486,6 @@ def create_user(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
account_type=AccountType.STANDARD,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
@@ -509,25 +506,13 @@ def create_user(
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
# Assign user to default group BEFORE commit so everything is atomic.
# If this fails, the entire user creation rolls back and IdP can retry.
try:
assign_user_to_default_groups__no_commit(db_session, user)
except Exception:
dal.rollback()
logger.exception(f"Failed to assign SCIM user {email} to default groups")
return _scim_error_response(
500, f"Failed to assign user {email} to default group"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,

View File

@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
@router.get("/admin/user-group")
def list_user_groups(
include_default: bool = False,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
eager_load_for_snapshot=True,
include_default=include_default,
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
@@ -60,50 +56,27 @@ def list_user_groups(
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
include_default=include_default,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
include_default: bool = False,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
include_default=include_default,
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
include_default=include_default,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
user_group_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[str]:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
return [
grant.permission.value
for grant in group.permission_grants
if not grant.is_deleted
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
group = fetch_user_group(db_session, rename_request.id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
try:
return UserGroup.from_model(
rename_user_group(
@@ -215,9 +185,6 @@ def delete_user_group(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
group = fetch_user_group(db_session, user_group_id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:

View File

@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
is_default=user_group_model.is_default,
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
is_default=user_group_model.is_default,
)

View File

@@ -1,110 +0,0 @@
"""
Permission resolution for group-based authorization.
Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""
from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any
from fastapi import Depends
from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
Permission.MANAGE_AGENTS.value: {
Permission.ADD_AGENTS.value,
Permission.READ_AGENTS.value,
},
Permission.MANAGE_DOCUMENT_SETS.value: {
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_CONNECTORS.value,
},
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
Permission.MANAGE_CONNECTORS.value: {
Permission.ADD_CONNECTORS.value,
Permission.READ_CONNECTORS.value,
},
Permission.MANAGE_USER_GROUPS.value: {
Permission.READ_CONNECTORS.value,
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_AGENTS.value,
Permission.READ_USERS.value,
},
}
def resolve_effective_permissions(granted: set[str]) -> set[str]:
"""Expand granted permissions with their implied permissions.
If "admin" is present, returns all 19 permissions.
"""
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
return set(ALL_PERMISSIONS)
effective = set(granted)
changed = True
while changed:
changed = False
for perm in list(effective):
implied = IMPLIED_PERMISSIONS.get(perm)
if implied and not implied.issubset(effective):
effective |= implied
changed = True
return effective
def get_effective_permissions(user: User) -> set[Permission]:
"""Read granted permissions from the column and expand implied permissions."""
granted: set[Permission] = set()
for p in user.effective_permissions:
try:
granted.add(Permission(p))
except ValueError:
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
return set(Permission)
expanded = resolve_effective_permissions({p.value for p in granted})
return {Permission(p) for p in expanded}
def require_permission(
required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
"""FastAPI dependency factory for permission-based access control.
Usage:
@router.get("/endpoint")
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
...
"""
async def dependency(user: User = Depends(current_user)) -> User:
effective = get_effective_permissions(user)
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
return user
if required not in effective:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
return user
return dependency

View File

@@ -5,8 +5,6 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override
from onyx.db.enums import AccountType
class UserRole(str, Enum):
"""
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
account_type: AccountType = AccountType.STANDARD
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
@@ -53,16 +50,12 @@ class UserCreate(schemas.BaseUserCreate):
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
# Force STANDARD for self-registration; only trusted paths
# (SCIM, API key creation) supply a different account_type directly.
d["account_type"] = AccountType.STANDARD
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
d.setdefault("account_type", self.account_type)
return d

View File

@@ -120,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
@@ -696,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
"account_type": AccountType.STANDARD,
}
user = await self.user_db.create(user_dict)
@@ -746,23 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
# Upgrade the user and assign default groups in a single
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
sync_user.account_type = AccountType.STANDARD
if was_inactive:
sync_user.is_active = True
assign_user_to_default_groups__no_commit(sync_db, sync_user)
sync_db.commit()
# Refresh the async user object so downstream code
# (e.g. oidc_expiry check) sees the updated fields.
user = await self.user_db.get(user.id) # type: ignore[arg-type]
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -848,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
event=MilestoneRecordType.TENANT_CREATED,
)
# Assign user to the appropriate default group (Admin or Basic).
# Must happen inside the try block while tenant context is active,
# otherwise get_session_with_current_tenant() targets the wrong schema.
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
with get_session_with_current_tenant() as db_session:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=is_admin
)
db_session.commit()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1576,7 +1554,6 @@ def get_anonymous_user() -> User:
is_verified=True,
is_superuser=False,
role=UserRole.LIMITED,
account_type=AccountType.ANONYMOUS,
use_memories=False,
enable_memory_tool=False,
)

View File

@@ -1,19 +1,8 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -159,114 +148,3 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -5,6 +5,7 @@ from typing import cast
from uuid import UUID
from fastapi.datastructures import Headers
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.chat.models import ChatHistoryResult
@@ -51,6 +52,60 @@ logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
class FileContextResult(BaseModel):
"""Result of building a file's LLM context representation."""
message: ChatMessageSimple
tool_metadata: FileToolMetadata
def build_file_context(
tool_file_id: str,
filename: str,
file_type: ChatFileType,
content_text: str | None = None,
token_count: int = 0,
approx_char_count: int | None = None,
) -> FileContextResult:
"""Build the LLM context representation for a single file.
Centralises how files should appear in the LLM prompt
— the ID that FileReaderTool accepts (``UserFile.id`` for user files).
"""
if file_type.use_metadata_only():
message_text = (
f"File: {filename} (id={tool_file_id})\n"
"Use the file_reader or python tools to access "
"this file's contents."
)
message = ChatMessageSimple(
message=message_text,
token_count=max(1, len(message_text) // 4),
message_type=MessageType.USER,
file_id=tool_file_id,
)
else:
message_text = f"File: {filename}\n{content_text or ''}\nEnd of File"
message = ChatMessageSimple(
message=message_text,
token_count=token_count,
message_type=MessageType.USER,
file_id=tool_file_id,
)
metadata = FileToolMetadata(
file_id=tool_file_id,
filename=filename,
approx_char_count=(
approx_char_count
if approx_char_count is not None
else len(content_text or "")
),
)
return FileContextResult(message=message, tool_metadata=metadata)
def create_chat_session_from_request(
chat_session_request: ChatSessionCreationRequest,
user_id: UUID | None,
@@ -538,7 +593,7 @@ def convert_chat_history(
for idx, chat_message in enumerate(chat_history):
if chat_message.message_type == MessageType.USER:
# Process files attached to this message
text_files: list[ChatLoadedFile] = []
text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = []
image_files: list[ChatLoadedFile] = []
if chat_message.files:
@@ -549,34 +604,26 @@ def convert_chat_history(
if loaded_file.file_type == ChatFileType.IMAGE:
image_files.append(loaded_file)
else:
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
text_files.append(loaded_file)
# Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages
text_files.append((loaded_file, file_descriptor))
# Add text files as separate messages before the user message.
# Each message is tagged with ``file_id`` so that forgotten files
# can be detected after context-window truncation.
for text_file in text_files:
file_text = text_file.content_text or ""
filename = text_file.filename
message = (
f"File: {filename}\n{file_text}\nEnd of File"
if filename
else file_text
)
simple_messages.append(
ChatMessageSimple(
message=message,
token_count=text_file.token_count,
message_type=MessageType.USER,
image_files=None,
file_id=text_file.file_id,
)
)
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
file_id=text_file.file_id,
filename=filename or "unknown",
approx_char_count=len(file_text),
for text_file, fd in text_files:
# Use user_file_id as the FileReaderTool accepts that.
# Fall back to the file-store path id.
tool_id = fd.get("user_file_id") or text_file.file_id
filename = text_file.filename or "unknown"
ctx = build_file_context(
tool_file_id=tool_id,
filename=filename,
file_type=text_file.file_type,
content_text=text_file.content_text,
token_count=text_file.token_count,
)
simple_messages.append(ctx.message)
all_injected_file_metadata[tool_id] = ctx.tool_metadata
# Sum token counts from image files (excluding project image files)
image_token_count = (

View File

@@ -1,19 +1,40 @@
import threading
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
def __init__(self, bus: Queue):
self.bus = bus
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
drain_done: Optional event set by ``_run_models`` when the drain loop
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
immediately so worker threads can exit fast.
"""
def __init__(
self,
merged_queue: Queue[tuple[int, Packet | Exception | object]],
model_idx: int = 0,
drain_done: threading.Event | None = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
self._drain_done = drain_done
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
if self._drain_done is not None and self._drain_done.is_set():
return
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
self._merged_queue.put((self._model_idx, tagged))

File diff suppressed because it is too large Load Diff

View File

@@ -278,7 +278,6 @@ class NotificationType(str, Enum):
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
USER_GROUP_ASSIGNMENT_FAILED = "user_group_assignment_failed"
class BlobType(str, Enum):

View File

@@ -11,19 +11,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -92,7 +87,6 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)
@@ -105,21 +99,7 @@ def insert_api_key(
)
db_session.add(api_key_row)
# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# LIMITED role service accounts should have no group membership.
# Late import to avoid circular dependency (api_key <- users <- api_key).
if api_key_args.role != UserRole.LIMITED:
from onyx.db.users import assign_user_to_default_groups__no_commit
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
result = db_session.execute(stmt)
chat_sessions = list(result.scalars().all())
chat_sessions = result.scalars().all()
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
return list(chat_sessions)
def delete_orphaned_search_docs(db_session: Session) -> None:
@@ -631,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -853,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -13,19 +13,19 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"
USER = "user"
SCIM = "scim"
SYSTEM = "system"
class IndexingStatus(str, PyEnum):

View File

@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)
effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
Enum(Permission, native_enum=False), nullable=False
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False

View File

@@ -3,7 +3,6 @@ from datetime import timezone
from uuid import UUID
from sqlalchemy import cast
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
@@ -91,18 +90,9 @@ def get_notifications(
notif_type: NotificationType | None = None,
include_dismissed: bool = True,
) -> list[Notification]:
if user is None:
user_filter = Notification.user_id.is_(None)
elif user.role == UserRole.ADMIN:
# Admins see their own notifications AND admin-targeted ones (user_id IS NULL)
user_filter = or_(
Notification.user_id == user.id,
Notification.user_id.is_(None),
)
else:
user_filter = Notification.user_id == user.id
query = select(Notification).where(user_filter)
query = select(Notification).where(
Notification.user_id == user.id if user else Notification.user_id.is_(None)
)
if not include_dismissed:
query = query.where(Notification.dismissed.is_(False))
if notif_type:

View File

@@ -1,97 +0,0 @@
"""
DB operations for recomputing user effective_permissions.
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""
from collections import defaultdict
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup
def recompute_user_permissions__no_commit(user_id: UUID, db_session: Session) -> None:
"""Recompute a single user's granted permissions from their group grants.
Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().
Does NOT commit — caller must commit the session.
"""
stmt = (
select(PermissionGrant.permission)
.join(
User__UserGroup,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id == user_id,
PermissionGrant.is_deleted.is_(False),
)
)
rows = db_session.execute(stmt).scalars().all()
# sorted for consistent ordering in DB — easier to read when debugging
granted = sorted({p.value for p in rows})
db_session.execute(
update(User).where(User.id == user_id).values(effective_permissions=granted)
)
def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.
Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = list(
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
)
if not user_ids:
return
# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(user_ids),
PermissionGrant.is_deleted.is_(False),
)
).all()
# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID, set[str]] = defaultdict(set)
for uid in user_ids:
perms_by_user[uid] # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)
for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid)
.values(effective_permissions=sorted(perms))
)

View File

@@ -19,7 +19,6 @@ from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.enums import AccountType
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -28,11 +27,8 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
@@ -302,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)
@@ -313,7 +308,6 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user
@@ -350,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
@@ -382,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users
def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.
Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.
Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return
target_group_name = "Admin" if is_admin else "Basic"
default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)
if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)
# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return
savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return
from onyx.db.permissions import recompute_user_permissions__no_commit
recompute_user_permissions__no_commit(user.id, db_session)
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
@@ -503,14 +421,13 @@ def delete_user_from_db(
def batch_get_user_groups(
db_session: Session,
user_ids: list[UUID],
include_default: bool = False,
) -> dict[UUID, list[tuple[int, str]]]:
"""Fetch group memberships for a batch of users in a single query.
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
if not user_ids:
return {}
stmt = (
rows = db_session.execute(
select(
User__UserGroup.user_id,
UserGroup.id,
@@ -518,11 +435,7 @@ def batch_get_user_groups(
)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.where(User__UserGroup.user_id.in_(user_ids))
)
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
rows = db_session.execute(stmt).all()
).all()
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
for user_id, group_id, group_name in rows:

View File

@@ -23,6 +23,11 @@ class ChatFileType(str, Enum):
ChatFileType.TABULAR,
)
def use_metadata_only(self) -> bool:
"""File types where we can ignore the file content
and only use the metadata."""
return self in (ChatFileType.TABULAR,)
class FileDescriptor(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column

View File

@@ -110,16 +110,20 @@ def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
# check for plain text normalized version first, then use original file otherwise
try:
file_io = file_store.read_file(plaintext_file_name, mode="b")
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
plaintext_chat_file_type = (
ChatFileType.PLAIN_TEXT
if chat_file_type != ChatFileType.IMAGE
else chat_file_type
)
# if we have plaintext for image (which happens when image extraction is enabled), we use PLAIN_TEXT type
if file_io is not None:
# Metadata-only file types preserve their original type so
# downstream injection paths can route them correctly.
if chat_file_type.use_metadata_only():
plaintext_chat_file_type = chat_file_type
elif file_io is not None:
# if we have plaintext for image (which happens when image
# extraction is enabled), we use PLAIN_TEXT type
plaintext_chat_file_type = ChatFileType.PLAIN_TEXT
else:
plaintext_chat_file_type = (
ChatFileType.PLAIN_TEXT
if chat_file_type != ChatFileType.IMAGE
else chat_file_type
)
chat_file = InMemoryChatFile(
file_id=str(user_file.file_id),

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.schemas import UserRole
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
@@ -774,13 +773,6 @@ def _get_token_created_at(
return get_current_token_creation_postgres(user, db_session)
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
def get_current_user_permissions(
user: User = Depends(current_user),
) -> list[str]:
return sorted(p.value for p in get_effective_permissions(user))
@router.get("/me", tags=PUBLIC_API_TAGS)
def verify_user_logged_in(
request: Request,

View File

@@ -7,7 +7,6 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.models import User
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
id: UUID
email: str
role: UserRole
account_type: AccountType
is_active: bool
password_configured: bool
personal_name: str | None
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
id=user.id,
email=user.email,
role=user.role,
account_type=user.account_type,
is_active=user.is_active,
password_configured=user.password_configured,
personal_name=user.personal_name,

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,11 +2,25 @@ from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to. ``0`` for single-model
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
for pre-LLM setup packets (e.g. message ID info) that are yielded
before any Emitter runs.
"""
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -1,3 +1,4 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -708,7 +709,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter_queue.qsize()}")

View File

@@ -1,5 +1,6 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -11,7 +12,6 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use default emitter if none provided
# Use a discard emitter if none provided (packets go nowhere)
if emitter is None:
emitter = get_default_emitter()
emitter = Emitter(merged_queue=queue.Queue())
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
dynamic_schema_info=None,
)

View File

@@ -27,11 +27,13 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_store.file_store import get_default_file_store
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def create_test_user(
db_session: Session,
email_prefix: str,
role: UserRole = UserRole.BASIC,
account_type: AccountType = AccountType.STANDARD,
) -> User:
def create_test_user(db_session: Session, email_prefix: str) -> User:
"""Helper to create a test user with a unique email"""
# Use UUID to ensure unique email addresses
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
@@ -74,8 +68,7 @@ def create_test_user(
is_active=True,
is_superuser=False,
is_verified=True,
role=role,
account_type=account_type,
role=UserRole.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import AccountType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.models import UserRole
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_ext_perm_user(db_session: Session, name: str) -> User:
"""Create an external-permission user for group sync tests."""
return create_test_user(
db_session,
name,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
def _create_test_connector_credential_pair(
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
) -> ConnectorCredentialPair:
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
def test_initial_group_sync(self, db_session: Session) -> None:
"""Test syncing external groups for the first time (initial sync)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Mock external groups data as a generator that yields the expected groups
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
def test_update_existing_groups(self, db_session: Session) -> None:
"""Test updating existing groups (adding/removing users)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with original groups
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
def test_remove_groups(self, db_session: Session) -> None:
"""Test removing groups (groups that no longer exist in external system)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with multiple groups
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
def test_empty_group_sync(self, db_session: Session) -> None:
"""Test syncing when no groups are returned (all groups removed)"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user1 = create_test_user(db_session, "user1")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with groups
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
# Create many test users
users = []
for i in range(150): # More than the batch size of 100
users.append(_create_ext_perm_user(db_session, f"user{i}"))
users.append(create_test_user(db_session, f"user{i}"))
cc_pair = _create_test_connector_credential_pair(db_session)
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
"""Test syncing a mix of regular and public groups"""
# Create test data
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
def mixed_group_sync_func(

View File

@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.enums import BuildSessionStatus
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -1,51 +0,0 @@
"""
Tests that account_type is correctly set when creating users through
the internal DB functions: add_slack_user_if_not_exists and
batch_add_ext_perm_user_if_not_exists.
These functions are called by background workers (Slack bot, permission sync)
and are not exposed via API endpoints, so they must be tested directly.
"""
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.models import UserRole
from onyx.db.users import add_slack_user_if_not_exists
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
users = batch_add_ext_perm_user_if_not_exists(
db_session, ["extperm_acct_type@test.com"]
)
assert len(users) == 1
user = users[0]
assert user.role == UserRole.EXT_PERM_USER
assert user.account_type == AccountType.EXT_PERM_USER
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
db_session: Session,
) -> None:
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
email = "ext_to_slack_acct_type@test.com"
# Create as ext_perm user first
batch_add_ext_perm_user_if_not_exists(db_session, [email])
# Now "upgrade" via slack path
user = add_slack_user_if_not_exists(db_session, email)
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT

View File

@@ -8,7 +8,6 @@ import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
is_superuser=True,
is_verified=True,
role=UserRole.ADMIN,
account_type=AccountType.STANDARD,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,6 +13,7 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -20,7 +21,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -16,7 +17,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)

View File

@@ -126,15 +126,6 @@ class UserManager:
return test_user
@staticmethod
def get_permissions(user: DATestUser) -> list[str]:
response = requests.get(
url=f"{API_SERVER_URL}/me/permissions",
headers=user.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def is_role(
user_to_verify: DATestUser,

View File

@@ -104,30 +104,13 @@ class UserGroupManager:
)
response.raise_for_status()
@staticmethod
def get_permissions(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
) -> list[str]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all(
user_performing_action: DATestUser,
include_default: bool = False,
) -> list[UserGroup]:
params: dict[str, str] = {}
if include_default:
params["include_default"] = "true"
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers,
params=params,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]

View File

@@ -1,13 +1,9 @@
from uuid import UUID
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
@@ -37,120 +33,3 @@ def test_limited(reset: None) -> None: # noqa: ARG001
headers=api_key.headers,
)
assert response.status_code == 403
def _get_service_account_account_type(
admin_user: DATestUser,
api_key_user_id: UUID,
) -> AccountType:
"""Fetch the account_type of a service account user via the user listing API."""
response = requests.get(
f"{API_SERVER_URL}/manage/users",
headers=admin_user.headers,
params={"include_api_keys": "true"},
)
response.raise_for_status()
data = response.json()
user_id_str = str(api_key_user_id)
for user in data["accepted"]:
if user["id"] == user_id_str:
return AccountType(user["account_type"])
raise AssertionError(
f"Service account user {user_id_str} not found in user listing"
)
def _get_default_group_user_ids(
admin_user: DATestUser,
) -> tuple[set[str], set[str]]:
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
admin_ids = {str(u.id) for u in admin_group.users}
basic_ids = {str(u.id) for u in basic_group.users}
return admin_ids, basic_ids
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.LIMITED,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify no group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert (
user_id_str not in admin_ids
), "LIMITED API key should NOT be in Admin default group"
assert (
user_id_str not in basic_ids
), "LIMITED API key should NOT be in Basic default group"
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.BASIC,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Basic group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
assert (
user_id_str not in admin_ids
), "BASIC API key should NOT be in Admin default group"
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.ADMIN,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Admin group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
assert (
user_id_str not in basic_ids
), "ADMIN API key should NOT be in Basic default group"

View File

@@ -4,10 +4,8 @@ import pytest
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@@ -97,63 +95,3 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_sets_account_type_and_group(
reset: None, # noqa: ARG001
) -> None:
"""
Test that SAML login sets account_type to STANDARD when converting a
non-web user (EXT_PERM_USER) and that the user receives the correct role
(BASIC) after conversion.
This validates the permissions-migration-phase2 changes which ensure that:
1. account_type is updated to 'standard' on SAML conversion
2. The converted user is assigned to the Basic default group
"""
# Create an admin user (first user is automatically admin)
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# Create a user and set them as EXT_PERM_USER
test_email = "ext_convert@example.com"
test_user = UserManager.create(email=test_email)
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_email},
headers=admin_user.headers,
)
response.raise_for_status()
user_data = response.json()
# Verify account_type is set to standard after conversion
assert (
user_data["account_type"] == AccountType.STANDARD.value
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
# Verify role is BASIC after conversion
assert user_data["role"] == UserRole.BASIC.value
# Verify the user was assigned to the Basic default group
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
assert basic_default, "Basic default group not found"
basic_group = basic_default[0]
member_emails = {u.email for u in basic_group.users}
assert test_email in member_emails, (
f"Converted user '{test_email}' not found in Basic default group members: "
f"{member_emails}"
)

View File

@@ -35,16 +35,9 @@ from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.enums import AccountType
from onyx.server.settings.models import ApplicationStatus
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
@@ -218,49 +211,6 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
_assert_entra_emails(body, email)
def test_create_user_default_group_and_account_type(
scim_token: str, idp_style: str
) -> None:
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
email = f"scim_defaults_{idp_style}@example.com"
ext_id = f"ext-defaults-{idp_style}"
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
assert resp.status_code == 201
user_id = resp.json()["id"]
# --- Verify group assignment via SCIM GET ---
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
assert get_resp.status_code == 200
groups = get_resp.json().get("groups", [])
group_names = {g["display"] for g in groups}
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
# --- Verify account_type via admin API ---
admin = UserManager.login_as_user(
DATestUser(
id="",
email=build_email(ADMIN_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
page = UserManager.get_user_page(
user_performing_action=admin,
search_query=email,
)
assert page.total_items >= 1
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
assert (
scim_user_snapshot is not None
), f"SCIM user {email} not found in user listing"
assert (
scim_user_snapshot.account_type == AccountType.STANDARD
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
def test_get_user(scim_token: str, idp_style: str) -> None:
"""GET /Users/{id} returns the user resource with all stored fields."""
email = f"scim_get_{idp_style}@example.com"

View File

@@ -1,118 +0,0 @@
import os
import pytest
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import UserGroup as UserGroupModel
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_user_gets_permissions_when_added_to_group(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
# basic_user starts with only "basic" from the default group
initial_permissions = UserManager.get_permissions(basic_user)
assert "basic" in initial_permissions
assert "add:agents" not in initial_permissions
# Create a new group and add basic_user
group = UserGroupManager.create(
name="perm-test-group",
user_ids=[admin_user.id, basic_user.id],
user_performing_action=admin_user,
)
# Grant a non-basic permission to the group and recompute
with get_session_with_current_tenant() as db_session:
db_group = db_session.get(UserGroupModel, group.id)
assert db_group is not None
db_session.add(
PermissionGrant(
group_id=db_group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
)
db_session.flush()
recompute_user_permissions__no_commit(basic_user.id, db_session)
db_session.commit()
# Verify the user gained the new permission (expanded includes read:agents)
updated_permissions = UserManager.get_permissions(basic_user)
assert (
"add:agents" in updated_permissions
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
assert (
"read:agents" in updated_permissions
), f"User should have implied 'read:agents', got: {updated_permissions}"
assert "basic" in updated_permissions
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_group_permission_change_propagates_to_all_members(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_propagate")
user_a: DATestUser = UserManager.create(name="user_a_propagate")
user_b: DATestUser = UserManager.create(name="user_b_propagate")
group = UserGroupManager.create(
name="propagate-test-group",
user_ids=[admin_user.id, user_a.id, user_b.id],
user_performing_action=admin_user,
)
# Neither user should have add:agents yet
for u in (user_a, user_b):
assert "add:agents" not in UserManager.get_permissions(u)
# Grant add:agents to the group, then batch-recompute
with get_session_with_current_tenant() as db_session:
grant = PermissionGrant(
group_id=group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
db_session.add(grant)
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
# Both users should now have the permission (plus implied read:agents)
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
# Soft-delete the grant and recompute — permission should be removed
with get_session_with_current_tenant() as db_session:
db_grant = (
db_session.query(PermissionGrant)
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
.first()
)
assert db_grant is not None
db_grant.is_deleted = True
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"

View File

@@ -1,30 +0,0 @@
import os
import pytest
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
user_group = UserGroupManager.create(
name="basic-perm-test-group",
user_ids=[admin_user.id],
user_performing_action=admin_user,
)
permissions = UserGroupManager.get_permissions(
user_group=user_group,
user_performing_action=admin_user,
)
assert (
"basic" in permissions
), f"New group should have 'basic' permission, got: {permissions}"

View File

@@ -1,78 +0,0 @@
"""Integration tests for default group assignment on user registration.
Verifies that:
- The first registered user is assigned to the Admin default group
- Subsequent registered users are assigned to the Basic default group
- account_type is set to STANDARD for email/password registrations
"""
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
# Register first user — should become admin
admin_user: DATestUser = UserManager.create(name="first_user")
assert admin_user.role == UserRole.ADMIN
# Register second user — should become basic
basic_user: DATestUser = UserManager.create(name="second_user")
assert basic_user.role == UserRole.BASIC
# Fetch all groups including default ones
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
# Find the default Admin and Basic groups
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
# Verify admin user is in Admin group and NOT in Basic group
admin_group_user_ids = {str(u.id) for u in admin_group.users}
basic_group_user_ids = {str(u.id) for u in basic_group.users}
assert (
admin_user.id in admin_group_user_ids
), "First user should be in Admin default group"
assert (
admin_user.id not in basic_group_user_ids
), "First user should NOT be in Basic default group"
# Verify basic user is in Basic group and NOT in Admin group
assert (
basic_user.id in basic_group_user_ids
), "Second user should be in Basic default group"
assert (
basic_user.id not in admin_group_user_ids
), "Second user should NOT be in Admin default group"
# Verify account_type is STANDARD for both users via user listing API
paginated_result = UserManager.get_user_page(
user_performing_action=admin_user,
page_num=0,
page_size=10,
)
users_by_id = {str(u.id): u for u in paginated_result.items}
admin_snapshot = users_by_id.get(admin_user.id)
basic_snapshot = users_by_id.get(basic_user.id)
assert admin_snapshot is not None, "Admin user not found in user listing"
assert basic_snapshot is not None, "Basic user not found in user listing"
assert (
admin_snapshot.account_type == AccountType.STANDARD
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
assert (
basic_snapshot.account_type == AccountType.STANDARD
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"

View File

@@ -1,176 +0,0 @@
"""
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
"""
from unittest.mock import MagicMock
import pytest
from onyx.auth.permissions import ALL_PERMISSIONS
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.permissions import require_permission
from onyx.auth.permissions import resolve_effective_permissions
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# resolve_effective_permissions
# ---------------------------------------------------------------------------
class TestResolveEffectivePermissions:
def test_empty_set(self) -> None:
assert resolve_effective_permissions(set()) == set()
def test_basic_no_implications(self) -> None:
result = resolve_effective_permissions({"basic"})
assert result == {"basic"}
def test_single_implication(self) -> None:
result = resolve_effective_permissions({"add:agents"})
assert result == {"add:agents", "read:agents"}
def test_manage_agents_implies_add_and_read(self) -> None:
"""manage:agents directly maps to {add:agents, read:agents}."""
result = resolve_effective_permissions({"manage:agents"})
assert result == {"manage:agents", "add:agents", "read:agents"}
def test_manage_connectors_chain(self) -> None:
result = resolve_effective_permissions({"manage:connectors"})
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
def test_manage_document_sets(self) -> None:
result = resolve_effective_permissions({"manage:document_sets"})
assert result == {
"manage:document_sets",
"read:document_sets",
"read:connectors",
}
def test_manage_user_groups_implies_all_reads(self) -> None:
result = resolve_effective_permissions({"manage:user_groups"})
assert result == {
"manage:user_groups",
"read:connectors",
"read:document_sets",
"read:agents",
"read:users",
}
def test_admin_override(self) -> None:
result = resolve_effective_permissions({"admin"})
assert result == set(ALL_PERMISSIONS)
def test_admin_with_others(self) -> None:
result = resolve_effective_permissions({"admin", "basic"})
assert result == set(ALL_PERMISSIONS)
def test_multi_group_union(self) -> None:
result = resolve_effective_permissions(
{"add:agents", "manage:connectors", "basic"}
)
assert result == {
"basic",
"add:agents",
"read:agents",
"manage:connectors",
"add:connectors",
"read:connectors",
}
def test_toggle_permission_no_implications(self) -> None:
result = resolve_effective_permissions({"read:agent_analytics"})
assert result == {"read:agent_analytics"}
def test_all_permissions_for_admin(self) -> None:
result = resolve_effective_permissions({"admin"})
assert len(result) == len(ALL_PERMISSIONS)
# ---------------------------------------------------------------------------
# get_effective_permissions (expands implied at read time)
# ---------------------------------------------------------------------------
class TestGetEffectivePermissions:
def test_expands_implied_permissions(self) -> None:
"""Column stores only granted; get_effective_permissions expands implied."""
user = MagicMock()
user.effective_permissions = ["add:agents"]
result = get_effective_permissions(user)
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
def test_admin_expands_to_all(self) -> None:
user = MagicMock()
user.effective_permissions = ["admin"]
result = get_effective_permissions(user)
assert result == set(Permission)
def test_basic_stays_basic(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
result = get_effective_permissions(user)
assert result == {Permission.BASIC_ACCESS}
def test_empty_column(self) -> None:
user = MagicMock()
user.effective_permissions = []
result = get_effective_permissions(user)
assert result == set()
# ---------------------------------------------------------------------------
# require_permission (FastAPI dependency)
# ---------------------------------------------------------------------------
class TestRequirePermission:
@pytest.mark.asyncio
async def test_admin_bypass(self) -> None:
"""Admin stored in column should pass any permission check."""
user = MagicMock()
user.effective_permissions = ["admin"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_has_required_permission(self) -> None:
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_implied_permission_passes(self) -> None:
"""manage:connectors implies read:connectors at read time."""
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.READ_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_missing_permission_raises(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
with pytest.raises(OnyxError) as exc_info:
await dep(user=user)
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
@pytest.mark.asyncio
async def test_empty_permissions_fails(self) -> None:
user = MagicMock()
user.effective_permissions = []
dep = require_permission(Permission.BASIC_ACCESS)
with pytest.raises(OnyxError):
await dep(user=user)

View File

@@ -1,29 +0,0 @@
"""
Unit tests for UserCreate schema dict methods.
Verifies that account_type is always included in create_update_dict
and create_update_dict_superuser.
"""
from onyx.auth.schemas import UserCreate
from onyx.db.enums import AccountType
def test_create_update_dict_includes_default_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_includes_explicit_account_type() -> None:
uc = UserCreate(
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
)
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_superuser_includes_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict_superuser()
assert d["account_type"] == AccountType.STANDARD

View File

@@ -300,6 +300,66 @@ class TestExtractContextFiles:
assert result.file_texts == []
assert result.total_token_count == 50
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_tool_metadata_file_id_matches_chat_history_file_id(
self, mock_load: MagicMock
) -> None:
"""The file_id in tool metadata (from extract_context_files) and the
file_id in chat history messages (from build_file_context) must
agree, otherwise the LLM sees different IDs for the same file across
turns.
In production, UserFile.id (UUID PK) differs from UserFile.file_id
(file-store path). Both pathways should produce the same file_id
(UserFile.id) for FileReaderTool."""
from onyx.chat.chat_utils import build_file_context
user_file_uuid = uuid4()
file_store_path = f"user_files/{user_file_uuid}/data.csv"
uf = UserFile(
id=user_file_uuid,
file_id=file_store_path,
name="data.csv",
token_count=100,
file_type="text/csv",
)
in_memory = InMemoryChatFile(
file_id=file_store_path,
content=b"col1,col2\na,b",
file_type=ChatFileType.TABULAR,
filename="data.csv",
)
mock_load.return_value = [in_memory]
# Pathway 1: extract_context_files (project/persona context)
result = extract_context_files(
user_files=[uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert len(result.file_metadata_for_tool) == 1
tool_metadata_file_id = result.file_metadata_for_tool[0].file_id
# Pathway 2: build_file_context (chat history path)
# In convert_chat_history, tool_file_id comes from
# file_descriptor["user_file_id"], which is str(UserFile.id)
ctx = build_file_context(
tool_file_id=str(user_file_uuid),
filename="data.csv",
file_type=ChatFileType.TABULAR,
)
chat_history_file_id = ctx.tool_metadata.file_id
# Both pathways must produce the same ID for the LLM
assert tool_metadata_file_id == chat_history_file_id, (
f"File ID mismatch: extract_context_files uses '{tool_metadata_file_id}' "
f"but build_file_context uses '{chat_history_file_id}'."
)
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
"""When vector DB is disabled, overflow produces FileToolMetadata."""
@@ -316,6 +376,128 @@ class TestExtractContextFiles:
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_metadata_only_files_not_counted_in_aggregate_tokens(
self, mock_load: MagicMock
) -> None:
"""Metadata-only files (TABULAR) should not count toward the token budget."""
text_file_id = str(uuid4())
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
# TABULAR file with large token count — should be excluded from aggregate
tabular_uf = _make_user_file(
token_count=50000, name="huge.xlsx", file_id=str(uuid4())
)
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
mock_load.return_value = [
_make_in_memory_file(file_id=text_file_id, content="text content"),
InMemoryChatFile(
file_id=str(tabular_uf.id),
content=b"binary xlsx",
file_type=ChatFileType.TABULAR,
filename="huge.xlsx",
),
]
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
# Text file fits (100 < 6000), so files should be loaded
assert result.file_texts == ["text content"]
# TABULAR file should appear as tool metadata, not in file_texts
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "huge.xlsx"
@patch("onyx.chat.process_message.load_in_memory_chat_files")
def test_metadata_only_files_loaded_as_tool_metadata(
self, mock_load: MagicMock
) -> None:
"""When files fit, metadata-only files appear in file_metadata_for_tool."""
text_file_id = str(uuid4())
tabular_file_id = str(uuid4())
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
tabular_uf = _make_user_file(
token_count=500, name="data.csv", file_id=tabular_file_id
)
tabular_uf.file_type = "text/csv"
mock_load.return_value = [
_make_in_memory_file(file_id=text_file_id, content="hello"),
InMemoryChatFile(
file_id=tabular_file_id,
content=b"col1,col2\na,b",
file_type=ChatFileType.TABULAR,
filename="data.csv",
),
]
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.file_texts == ["hello"]
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "data.csv"
# TABULAR should not appear in file_metadata (that's for citation)
assert all(m.filename != "data.csv" for m in result.file_metadata)
def test_overflow_with_vector_db_preserves_metadata_only_tool_metadata(
self,
) -> None:
"""When text files overflow with vector DB enabled, metadata-only files
should still be exposed via file_metadata_for_tool since they aren't
in the vector DB and would otherwise be inaccessible."""
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
# Text files overflow → search filter enabled
assert result.use_as_search_filter is True
assert result.file_texts == []
# TABULAR file should still be in tool metadata
assert len(result.file_metadata_for_tool) == 1
assert result.file_metadata_for_tool[0].filename == "data.xlsx"
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
def test_overflow_no_vector_db_includes_all_files_in_tool_metadata(self) -> None:
"""When vector DB is disabled and files overflow, all files
(both text and metadata-only) appear in file_metadata_for_tool."""
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
tabular_uf.file_type = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
result = extract_context_files(
user_files=[text_uf, tabular_uf],
llm_max_context_window=10000,
reserved_token_count=0,
db_session=MagicMock(),
)
assert result.use_as_search_filter is False
assert len(result.file_metadata_for_tool) == 2
filenames = {m.filename for m in result.file_metadata_for_tool}
assert filenames == {"bigfile.txt", "data.xlsx"}
# ===========================================================================
# Search filter + search_usage determination

View File

@@ -0,0 +1,173 @@
"""Unit tests for the Emitter class.
All tests use the streaming mode (merged_queue required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
"""Return (emitter, queue) wired together."""
mq: queue.Queue = queue.Queue()
return Emitter(merged_queue=mq, model_idx=model_idx), mq
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
def test_emit_lands_on_merged_queue(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet())
assert not mq.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
def test_multiple_packets_delivered_fifo(self) -> None:
emitter, mq = _make_emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
_, t1 = mq.get_nowait()
_, t2 = mq.get_nowait()
assert t1.placement.turn_index == 0
assert t2.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
def test_key_equals_model_idx(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_n1_key_is_zero(self) -> None:
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
def test_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
emitter, mq = _make_emitter()
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
emitter, mq = _make_emitter()
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)

View File

@@ -644,6 +644,92 @@ class TestConstructMessageHistory:
assert "Project file 0 content" in project_message.message
assert "Project file 1 content" in project_message.message
def test_file_metadata_for_tool_produces_message(self) -> None:
"""When context_files has file_metadata_for_tool, a metadata listing
message should be injected into the history."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Analyze the spreadsheet", MessageType.USER, 5)
context_files = ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=False,
total_token_count=0,
file_metadata=[],
uncapped_token_count=0,
file_metadata_for_tool=[
FileToolMetadata(
file_id="xlsx-1",
filename="report.xlsx",
approx_char_count=100000,
),
],
)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=[user_msg],
reminder_message=None,
context_files=context_files,
available_tokens=1000,
token_counter=_simple_token_counter,
)
# Should have: system, tool_metadata_message, user
assert len(result) == 3
metadata_msg = result[1]
assert metadata_msg.message_type == MessageType.USER
assert "report.xlsx" in metadata_msg.message
assert "xlsx-1" in metadata_msg.message
def test_metadata_only_and_text_files_both_present(self) -> None:
"""When both text content and tool metadata are present, both messages
should appear in the history."""
system_prompt = create_message("System", MessageType.SYSTEM, 10)
user_msg = create_message("Summarize everything", MessageType.USER, 5)
context_files = ExtractedContextFiles(
file_texts=["Text file content here"],
image_files=[],
use_as_search_filter=False,
total_token_count=100,
file_metadata=[
ContextFileMetadata(
file_id="txt-1",
filename="notes.txt",
file_content="Text file content here",
),
],
uncapped_token_count=100,
file_metadata_for_tool=[
FileToolMetadata(
file_id="xlsx-1",
filename="data.xlsx",
approx_char_count=50000,
),
],
)
result = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=[user_msg],
reminder_message=None,
context_files=context_files,
available_tokens=2000,
token_counter=_simple_token_counter,
)
# Should have: system, context_files_message, tool_metadata_message, user
assert len(result) == 4
# Context files message (text content)
assert "documents" in result[1].message
assert "Text file content here" in result[1].message
# Tool metadata message
assert "data.xlsx" in result[2].message
assert result[3] == user_msg
def _simple_token_counter(text: str) -> int:
"""Approximate token counter for tests (~4 chars per token)."""

View File

@@ -0,0 +1,768 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
"""Reset EE global state after each test.
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
(via the celery client import chain). Without this fixture, the EE flag stays
True for the rest of the session and breaks unrelated tests that mock Confluence
or other connectors and assume EE is disabled.
"""
original = global_version._is_ee
yield
global_version._is_ee = original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
return next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
setup = _make_setup(n_models=2)
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 failed")
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_stop_button_calls_completion_for_all_models(self) -> None:
"""llm_loop_completion_handle must be called for all models when the stop button fires.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator. The finally block sets drain_done (signalling
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
server thread is never blocked. Worker threads detect drain_done.is_set()
after run_llm_loop completes and self-persist the result via
llm_loop_completion_handle using their own DB session.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
# so the self-completion path (drain_done.is_set() check) is always taken.
disconnect_received = threading.Event()
# Set by the llm_loop_completion_handle mock when called.
completion_called = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give the drain loop a yield point), then block
until the main thread signals that gen.close() has been called. This
ensures drain_done is set before we return so model_succeeded is checked
against a set drain_done — no race condition.
"""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
disconnect_received.wait(timeout=5)
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
# cast to Generator so .close() is available; _run_models returns
# AnswerStream (= Iterator) but the actual object is always a generator.
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Unblock the worker now that drain_done has been set by gen.close().
disconnect_received.set()
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
# Wait here, inside the patch context, so that get_session_with_current_tenant
# and llm_loop_completion_handle mocks are still active when the worker calls them.
assert completion_called.wait(
timeout=5
), "worker must self-complete via drain_done within 5 seconds"
assert (
mock_handle.call_count == 1
), "completion handle must be called once for the successful model"
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
"""B1 regression: model finishes BEFORE GeneratorExit fires.
The worker exits _run_model with drain_done.is_set()=False and skips
self-completion. When gen.close() fires afterward, the finally else-branch
must detect model_succeeded=True and call llm_loop_completion_handle itself.
Contrast with test_http_disconnect_completion_via_generator_exit, which
tests the opposite ordering (worker finishes AFTER disconnect).
"""
import threading
import time
completion_called = threading.Event()
def emit_and_return_immediately(**kwargs: Any) -> None:
# Emit one packet so the drain loop has something to yield, then return
# immediately — no blocking. The worker will be done in microseconds.
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_and_return_immediately,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
first = next(gen)
assert isinstance(first, Packet)
# Give the worker thread time to finish completely (emit + return +
# finally + self-completion check). It does almost no work, so 100 ms
# is far more than enough while still keeping the test fast.
time.sleep(0.1)
# Now close — worker is already done, so else-branch handles completion.
gen.close()
assert completion_called.wait(
timeout=5
), "disconnect handler must call completion for a model that already finished"
assert mock_handle.call_count == 1, "completion must be called exactly once"
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
"""B2 regression: stop-button must NOT call completion for an errored model.
When model 0 raises an exception, its reserved ChatMessage must not be
saved with 'stopped by user' — that message is wrong for a model that
errored. llm_loop_completion_handle must only be called for non-errored
models when the stop button fires.
"""
def fail_model_0(**kwargs: Any) -> None:
if kwargs["llm"] is setup.llms[0]:
raise RuntimeError("model 0 errored")
# Model 1: run forever (stop button fires before it finishes)
time.sleep(10)
setup = _make_setup(n_models=2)
# Return False immediately so the stop-button path fires while model 1
# is still sleeping (model 0 has already errored by then).
setup.check_is_connected = lambda: False
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Completion must NOT be called for model 0 (it errored).
# It MAY be called for model 1 (still in-flight when stop fired).
for call in mock_handle.call_args_list:
assert (
call.kwargs.get("llm") is not setup.llms[0]
), "llm_loop_completion_handle must not be called for the errored model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -1,176 +0,0 @@
"""
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
Covers:
1. Standard/service-account users get assigned to the correct default group
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
3. Missing default group raises RuntimeError
4. Already-in-group is a no-op
5. IntegrityError race condition is handled gracefully
6. The function never commits the session
"""
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.exc import IntegrityError
from onyx.db.enums import AccountType
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.users import assign_user_to_default_groups__no_commit
def _mock_user(
account_type: AccountType = AccountType.STANDARD,
email: str = "test@example.com",
) -> MagicMock:
user = MagicMock()
user.id = uuid4()
user.email = email
user.account_type = account_type
return user
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
group = MagicMock()
group.id = group_id
group.name = name
group.is_default = True
return group
def _make_query_chain(first_return: object = None) -> MagicMock:
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
chain = MagicMock()
chain.filter.return_value = chain
chain.first.return_value = first_return
return chain
def _setup_db_session(
group_result: object = None,
membership_result: object = None,
) -> MagicMock:
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
db_session = MagicMock()
group_chain = _make_query_chain(group_result)
membership_chain = _make_query_chain(membership_result)
def query_side_effect(model: type) -> MagicMock:
if model is UserGroup:
return group_chain
if model is User__UserGroup:
return membership_chain
return MagicMock()
db_session.query.side_effect = query_side_effect
return db_session
def test_standard_user_assigned_to_basic_group() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.STANDARD)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
db_session.add.assert_called_once()
added = db_session.add.call_args[0][0]
assert isinstance(added, User__UserGroup)
assert added.user_id == user.id
assert added.user_group_id == group.id
db_session.flush.assert_called_once()
def test_admin_user_assigned_to_admin_group() -> None:
group = _mock_group("Admin", group_id=2)
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.STANDARD)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
db_session.add.assert_called_once()
added = db_session.add.call_args[0][0]
assert isinstance(added, User__UserGroup)
assert added.user_group_id == group.id
@pytest.mark.parametrize(
"account_type",
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
)
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
db_session = MagicMock()
user = _mock_user(account_type)
assign_user_to_default_groups__no_commit(db_session, user)
db_session.query.assert_not_called()
db_session.add.assert_not_called()
def test_service_account_not_skipped() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.SERVICE_ACCOUNT)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
db_session.add.assert_called_once()
def test_missing_default_group_raises_error() -> None:
db_session = _setup_db_session(group_result=None)
user = _mock_user()
with pytest.raises(RuntimeError, match="Default group .* not found"):
assign_user_to_default_groups__no_commit(db_session, user)
def test_already_in_group_is_noop() -> None:
group = _mock_group("Basic")
existing_membership = MagicMock()
db_session = _setup_db_session(
group_result=group, membership_result=existing_membership
)
user = _mock_user()
assign_user_to_default_groups__no_commit(db_session, user)
db_session.add.assert_not_called()
db_session.begin_nested.assert_not_called()
def test_integrity_error_race_condition_handled() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
user = _mock_user()
# Should not raise
assign_user_to_default_groups__no_commit(db_session, user)
savepoint.rollback.assert_called_once()
def test_no_commit_called_on_successful_assignment() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user()
assign_user_to_default_groups__no_commit(db_session, user)
db_session.commit.assert_not_called()

View File

@@ -3,7 +3,6 @@ from unittest.mock import MagicMock
from uuid import uuid4
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.server.models import FullUserSnapshot
from onyx.server.models import UserGroupInfo
@@ -26,7 +25,6 @@ def _mock_user(
user.updated_at = updated_at or datetime.datetime(
2025, 6, 15, tzinfo=datetime.timezone.utc
)
user.account_type = AccountType.STANDARD
return user

View File

@@ -1,6 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
import queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
def emitter_queue() -> queue.Queue:
return queue.Queue()
@pytest.fixture
def emitter(emitter_queue: queue.Queue) -> Emitter:
return Emitter(merged_queue=emitter_queue)
@pytest.fixture
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
def test_emit_start_emits_memory_tool_start_packet(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
) -> None:
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolStart)
assert packet.placement == placement
assert packet.placement is not None
assert packet.placement.turn_index == placement.turn_index
assert packet.placement.tab_index == placement.tab_index
assert packet.placement.model_index == 0 # emitter stamps model_index=0
def test_emit_start_with_different_placement(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
) -> None:
placement = Placement(turn_index=2, tab_index=1)
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert packet.placement.turn_index == 2
assert packet.placement.tab_index == 1
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
memory="User prefers Python",
)
# The delta packet should be in the queue
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers Python"
assert packet.obj.operation == "add"
assert packet.obj.memory_id is None
assert packet.obj.index is None
assert packet.placement == placement
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
def test_run_emits_delta_for_update_operation(
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
memory="User prefers light mode",
)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers light mode"
assert packet.obj.operation == "update"

View File

@@ -1,153 +0,0 @@
import { Form, Formik } from "formik";
import { toast } from "@/hooks/useToast";
import { createApiKey, updateApiKey } from "./lib";
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
import { FormField } from "@/refresh-components/form/FormField";
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
import { APIKey } from "./types";
import { SvgKey } from "@opal/icons";
export interface OnyxApiKeyFormProps {
onClose: () => void;
onCreateApiKey: (apiKey: APIKey) => void;
apiKey?: APIKey;
}
export default function OnyxApiKeyForm({
onClose,
onCreateApiKey,
apiKey,
}: OnyxApiKeyFormProps) {
const isUpdate = apiKey !== undefined;
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="sm" height="lg">
<Modal.Header
icon={SvgKey}
title={isUpdate ? "Update API Key" : "Create a new API Key"}
onClose={onClose}
/>
<Formik
initialValues={{
name: apiKey?.api_key_name || "",
role: apiKey?.api_key_role || UserRole.BASIC.toString(),
}}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
// Prepare the payload with the UserRole
const payload = {
...values,
role: values.role as UserRole, // Assign the role directly as a UserRole type
};
let response;
if (isUpdate) {
response = await updateApiKey(apiKey.api_key_id, payload);
} else {
response = await createApiKey(payload);
}
formikHelpers.setSubmitting(false);
if (response.ok) {
toast.success(
isUpdate
? "Successfully updated API key!"
: "Successfully created API key!"
);
if (!isUpdate) {
onCreateApiKey(await response.json());
}
onClose();
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
toast.error(
isUpdate
? `Error updating API key - ${errorMsg}`
: `Error creating API key - ${errorMsg}`
);
}
}}
>
{({ isSubmitting }) => (
<Form className="w-full overflow-visible">
<Modal.Body>
<Text as="p">
Choose a memorable name for your API key. This is optional and
can be added or changed later!
</Text>
<FormikField<string>
name="name"
render={(field, helper, _meta, state) => (
<FormField name="name" state={state} className="w-full">
<FormField.Label>Name (optional):</FormField.Label>
<FormField.Control>
<InputTypeIn
{...field}
placeholder=""
onClear={() => helper.setValue("")}
showClearButton={false}
/>
</FormField.Control>
</FormField>
)}
/>
<FormikField<string>
name="role"
render={(field, helper, _meta, state) => (
<FormField name="role" state={state} className="w-full">
<FormField.Label>Role:</FormField.Label>
<FormField.Control>
<InputSelect
value={field.value}
onValueChange={(value) => helper.setValue(value)}
>
<InputSelect.Trigger placeholder="Select a role" />
<InputSelect.Content>
<InputSelect.Item
value={UserRole.LIMITED.toString()}
>
{USER_ROLE_LABELS[UserRole.LIMITED]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.BASIC.toString()}>
{USER_ROLE_LABELS[UserRole.BASIC]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.ADMIN.toString()}>
{USER_ROLE_LABELS[UserRole.ADMIN]}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</FormField.Control>
<FormField.Description>
Select the role for this API key. Limited has access to
simple public APIs. Basic has access to regular user
APIs. Admin has access to admin level APIs.
</FormField.Description>
</FormField>
)}
/>
</Modal.Body>
<Modal.Footer>
<Disabled disabled={isSubmitting}>
<Button type="submit">
{isUpdate ? "Update" : "Create"}
</Button>
</Disabled>
</Modal.Footer>
</Form>
)}
</Formik>
</Modal.Content>
</Modal>
);
}

View File

@@ -1,39 +0,0 @@
import { APIKeyArgs, APIKey } from "./types";
export const createApiKey = async (apiKeyArgs: APIKeyArgs) => {
return fetch("/api/admin/api-key", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(apiKeyArgs),
});
};
export const regenerateApiKey = async (apiKey: APIKey) => {
return fetch(`/api/admin/api-key/${apiKey.api_key_id}/regenerate`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
};
export const updateApiKey = async (
apiKeyId: number,
apiKeyArgs: APIKeyArgs
) => {
return fetch(`/api/admin/api-key/${apiKeyId}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(apiKeyArgs),
});
};
export const deleteApiKey = async (apiKeyId: number) => {
return fetch(`/api/admin/api-key/${apiKeyId}`, {
method: "DELETE",
});
};

View File

@@ -1,259 +0,0 @@
"use client";
import { ThreeDotsLoader } from "@/components/Loading";
import { errorHandlingFetcher } from "@/lib/fetcher";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ErrorCallout } from "@/components/ErrorCallout";
import useSWR, { mutate } from "swr";
import Separator from "@/refresh-components/Separator";
import {
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
Table,
} from "@/components/ui/table";
import Title from "@/components/ui/title";
import { toast } from "@/hooks/useToast";
import { useState } from "react";
import { DeleteButton } from "@/components/DeleteButton";
import Modal from "@/refresh-components/Modal";
import { Spinner } from "@/components/Spinner";
import { deleteApiKey, regenerateApiKey } from "@/app/admin/api-key/lib";
import OnyxApiKeyForm from "@/app/admin/api-key/OnyxApiKeyForm";
import {
APIKey,
DISCORD_SERVICE_API_KEY_NAME,
} from "@/app/admin/api-key/types";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import Message from "@/refresh-components/messages/Message";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.API_KEYS;
function Main() {
const {
data: apiKeys,
isLoading,
error,
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
const canCreateKeys = useCloudSubscription();
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
const [showCreateUpdateForm, setShowCreateUpdateForm] = useState(false);
const [selectedApiKey, setSelectedApiKey] = useState<APIKey | undefined>();
const handleEdit = (apiKey: APIKey) => {
setSelectedApiKey(apiKey);
setShowCreateUpdateForm(true);
};
if (isLoading) {
return <ThreeDotsLoader />;
}
if (!apiKeys || error) {
return (
<ErrorCallout
errorTitle="Failed to fetch API Keys"
errorMsg={error?.info?.detail || error.toString()}
/>
);
}
// Filter out the discord service key from the displayed list
const filteredApiKeys = apiKeys.filter(
(key) => key.api_key_name !== DISCORD_SERVICE_API_KEY_NAME
);
const introSection = (
<div className="flex flex-col items-start gap-4">
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<Text as="p">
API Keys allow you to access Onyx APIs programmatically.
{canCreateKeys
? " Click the button below to generate a new API Key."
: ""}
</Text>
{canCreateKeys ? (
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
Create API Key
</CreateButton>
) : isTrialing ? (
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
) : null}
</div>
);
if (filteredApiKeys.length === 0) {
return (
<div>
{introSection}
{showCreateUpdateForm && (
<OnyxApiKeyForm
onCreateApiKey={(apiKey) => {
setFullApiKey(apiKey.api_key);
}}
onClose={() => {
setShowCreateUpdateForm(false);
setSelectedApiKey(undefined);
mutate("/api/admin/api-key");
}}
apiKey={selectedApiKey}
/>
)}
</div>
);
}
return (
<>
<Modal open={!!fullApiKey}>
<Modal.Content width="sm" height="sm">
<Modal.Header
title="New API Key"
icon={SvgKey}
onClose={() => setFullApiKey(null)}
description="Make sure you copy your new API key. You won't be able to see this key again."
/>
<Modal.Body>
<Text as="p" className="break-all flex-1">
{fullApiKey}
</Text>
<CopyIconButton getCopyText={() => fullApiKey!} />
</Modal.Body>
</Modal.Content>
</Modal>
{keyIsGenerating && <Spinner />}
{introSection}
{canCreateKeys && (
<>
<Separator />
<Title className="mt-6">Existing API Keys</Title>
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>API Key</TableHead>
<TableHead>Role</TableHead>
<TableHead>Regenerate</TableHead>
<TableHead>Delete</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{filteredApiKeys.map((apiKey) => (
<TableRow key={apiKey.api_key_id}>
<TableCell>
<Button
prominence="internal"
onClick={() => handleEdit(apiKey)}
icon={SvgEdit}
>
{apiKey.api_key_name || "null"}
</Button>
</TableCell>
<TableCell className="max-w-64">
{apiKey.api_key_display}
</TableCell>
<TableCell className="max-w-64">
{apiKey.api_key_role.toUpperCase()}
</TableCell>
<TableCell>
<Button
prominence="internal"
icon={SvgRefreshCw}
onClick={async () => {
setKeyIsGenerating(true);
const response = await regenerateApiKey(apiKey);
setKeyIsGenerating(false);
if (!response.ok) {
const errorMsg = await response.text();
toast.error(
`Failed to regenerate API Key: ${errorMsg}`
);
return;
}
const newKey = (await response.json()) as APIKey;
setFullApiKey(newKey.api_key);
mutate("/api/admin/api-key");
}}
>
Refresh
</Button>
</TableCell>
<TableCell>
<DeleteButton
onClick={async () => {
const response = await deleteApiKey(apiKey.api_key_id);
if (!response.ok) {
const errorMsg = await response.text();
toast.error(`Failed to delete API Key: ${errorMsg}`);
return;
}
mutate("/api/admin/api-key");
}}
/>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
{showCreateUpdateForm && (
<OnyxApiKeyForm
onCreateApiKey={(apiKey) => {
setFullApiKey(apiKey.api_key);
}}
onClose={() => {
setShowCreateUpdateForm(false);
setSelectedApiKey(undefined);
mutate("/api/admin/api-key");
}}
apiKey={selectedApiKey}
/>
)}
</>
)}
</>
);
}
export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -0,0 +1 @@
export { default } from "@/refresh-pages/admin/ServiceAccountsPage";

View File

@@ -182,7 +182,8 @@ export async function* sendMessage({
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
const data = await response.json().catch(() => ({}));
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
}
yield* handleSSEStream<PacketType>(response, signal);

View File

@@ -4,7 +4,7 @@ import { useCallback } from "react";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { AccountType, UserStatus } from "@/lib/types";
import { UserStatus } from "@/lib/types";
import type { UserRole, InvitedUserSnapshot } from "@/lib/types";
import type {
UserRow,
@@ -19,7 +19,6 @@ interface FullUserSnapshot {
id: string;
email: string;
role: UserRole;
account_type: AccountType;
is_active: boolean;
password_configured: boolean;
personal_name: string | null;

View File

@@ -901,6 +901,11 @@ export default function useChatController({
});
}
}
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
// catch block replaces the thinking placeholder with an error message.
if (stack.error) {
throw new Error(stack.error);
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;

View File

@@ -181,7 +181,7 @@ export const ADMIN_ROUTES = {
sidebarLabel: "Users",
},
API_KEYS: {
path: "/admin/api-key",
path: "/admin/service-accounts",
icon: SvgUserKey,
title: "Service Accounts",
sidebarLabel: "Service Accounts",

View File

@@ -52,14 +52,6 @@ export interface UserPersonalization {
user_preferences: string;
}
export enum AccountType {
STANDARD = "STANDARD",
BOT = "BOT",
EXT_PERM_USER = "EXT_PERM_USER",
SERVICE_ACCOUNT = "SERVICE_ACCOUNT",
ANONYMOUS = "ANONYMOUS",
}
export enum UserRole {
LIMITED = "limited",
BASIC = "basic",
@@ -487,7 +479,6 @@ export interface UserGroup {
personas: Persona[];
is_up_to_date: boolean;
is_up_for_deletion: boolean;
is_default: boolean;
}
export enum ValidSources {

View File

@@ -87,7 +87,7 @@ function CreateGroupPage() {
const headerActions = (
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
<Button
prominence="tertiary"
prominence="secondary"
onClick={() => router.push("/admin/groups")}
>
Cancel
@@ -102,7 +102,7 @@ function CreateGroupPage() {
);
return (
<SettingsLayouts.Root width="sm">
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgUsers}
title="Create Group"

View File

@@ -287,7 +287,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
// 404 state
if (!isLoading && !error && !group) {
return (
<SettingsLayouts.Root width="sm">
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgUsers}
title="Group Not Found"
@@ -307,7 +307,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
const headerActions = (
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
<Button
prominence="tertiary"
prominence="secondary"
onClick={() => router.push("/admin/groups")}
>
Cancel
@@ -328,7 +328,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
return (
<>
<SettingsLayouts.Root width="sm">
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgUsers}
title="Edit Group"

View File

@@ -4,16 +4,14 @@ import type { Route } from "next";
import { useState } from "react";
import { useRouter } from "next/navigation";
import useSWR from "swr";
import { SvgPlusCircle, SvgUsers } from "@opal/icons";
import { Button } from "@opal/components";
import { SvgUsers } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { errorHandlingFetcher } from "@/lib/fetcher";
import type { UserGroup } from "@/lib/types";
import { USER_GROUP_URL } from "./svc";
import GroupsList from "./GroupsList";
import { Section } from "@/layouts/general-layouts";
import AdminListHeader from "@/sections/admin/AdminListHeader";
import { IllustrationContent } from "@opal/layouts";
import SvgNoResult from "@opal/illustrations/no-result";
@@ -28,34 +26,22 @@ function GroupsPage() {
} = useSWR<UserGroup[]>(USER_GROUP_URL, errorHandlingFetcher);
return (
<SettingsLayouts.Root width="sm">
{/* This is the sticky header for the groups page. It is used to display
* the groups page title and search input when scrolling down.
*/}
<div
className="sticky top-0 z-settings-header bg-background-tint-01"
data-testid="groups-page-heading"
>
<SettingsLayouts.Root>
<div data-testid="groups-page-heading">
<SettingsLayouts.Header icon={SvgUsers} title="Groups" separator />
<Section flexDirection="row" padding={1}>
<InputTypeIn
placeholder="Search groups..."
variant="internal"
value={searchQuery}
leftSearchIcon
onChange={(e) => setSearchQuery(e.target.value)}
/>
<Button
icon={SvgPlusCircle}
onClick={() => router.push("/admin/groups/create" as Route)}
>
New Group
</Button>
</Section>
</div>
<SettingsLayouts.Body>
<AdminListHeader
hasItems={!isLoading && !error && (groups?.length ?? 0) > 0}
searchQuery={searchQuery}
onSearchQueryChange={setSearchQuery}
placeholder="Search groups..."
emptyStateText="Create groups to organize users and manage access."
onAction={() => router.push("/admin/groups/create" as Route)}
actionLabel="New Group"
/>
{isLoading && <SimpleLoader />}
{error && (

View File

@@ -1,8 +1,10 @@
import type { UserGroup } from "@/lib/types";
/** Whether this group is a system default group (Admin, Basic). */
/** Groups that are created by the system and cannot be deleted. */
export const BUILT_IN_GROUP_NAMES = ["Basic", "Admin"] as const;
export function isBuiltInGroup(group: UserGroup): boolean {
return group.is_default;
return (BUILT_IN_GROUP_NAMES as readonly string[]).includes(group.name);
}
/** Human-readable description for built-in groups. */

View File

@@ -0,0 +1,175 @@
"use client";
import { Form, Formik } from "formik";
import { toast } from "@/hooks/useToast";
import {
createApiKey,
updateApiKey,
} from "@/refresh-pages/admin/ServiceAccountsPage/svc";
import type { APIKey } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
import { Vertical as VerticalInput } from "@/layouts/input-layouts";
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
import { SvgKey, SvgLock, SvgUser, SvgUserManage } from "@opal/icons";
interface ApiKeyFormModalProps {
onClose: () => void;
onCreateApiKey: (apiKey: APIKey) => void;
apiKey?: APIKey;
}
export default function ApiKeyFormModal({
onClose,
onCreateApiKey,
apiKey,
}: ApiKeyFormModalProps) {
const isUpdate = apiKey !== undefined;
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="sm" height="lg">
<Modal.Header
icon={SvgKey}
title={isUpdate ? "Update Service Account" : "Create Service Account"}
description={
isUpdate
? undefined
: "Use service account API key to programmatically access Onyx API with user-level permissions. You can modify the account details later."
}
onClose={onClose}
/>
<Formik
initialValues={{
name: apiKey?.api_key_name || "",
role: apiKey?.api_key_role || UserRole.BASIC.toString(),
}}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const payload = {
...values,
role: values.role as UserRole,
};
try {
let response;
if (isUpdate) {
response = await updateApiKey(apiKey.api_key_id, payload);
} else {
response = await createApiKey(payload);
}
if (response.ok) {
toast.success(
isUpdate
? "Successfully updated service account!"
: "Successfully created service account!"
);
if (!isUpdate) {
onCreateApiKey(await response.json());
}
onClose();
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
toast.error(
isUpdate
? `Error updating service account - ${errorMsg}`
: `Error creating service account - ${errorMsg}`
);
}
} catch (e) {
toast.error(
e instanceof Error ? e.message : "An unexpected error occurred."
);
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting, values }) => (
<Form className="w-full overflow-visible">
<Modal.Body>
<VerticalInput
name="name"
title="Name"
nonInteractive
sizePreset="main-ui"
>
<FormikField<string>
name="name"
render={(field, helper) => (
<InputTypeIn
{...field}
placeholder="Enter a name"
onClear={() => helper.setValue("")}
showClearButton={false}
/>
)}
/>
</VerticalInput>
<VerticalInput
name="role"
title="Account Permissions"
nonInteractive
sizePreset="main-ui"
>
<FormikField<string>
name="role"
render={(field, helper) => (
<InputSelect
value={field.value}
onValueChange={(value) => helper.setValue(value)}
>
<InputSelect.Trigger placeholder="Select permissions" />
<InputSelect.Content>
<InputSelect.Item
value={UserRole.ADMIN.toString()}
icon={SvgUserManage}
description="Unrestricted admin access to all endpoints."
>
{USER_ROLE_LABELS[UserRole.ADMIN]}
</InputSelect.Item>
<InputSelect.Item
value={UserRole.BASIC.toString()}
icon={SvgUser}
description="Standard user-level access to non-admin endpoints."
>
{USER_ROLE_LABELS[UserRole.BASIC]}
</InputSelect.Item>
<InputSelect.Item
value={UserRole.LIMITED.toString()}
icon={SvgLock}
description="For agents: chat posting and read-only access to other endpoints."
>
{USER_ROLE_LABELS[UserRole.LIMITED]}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
)}
/>
</VerticalInput>
</Modal.Body>
<Modal.Footer>
<Button prominence="secondary" type="button" onClick={onClose}>
Cancel
</Button>
<Disabled disabled={isSubmitting || !values.name.trim()}>
<Button type="submit">
{isUpdate ? "Update" : "Create Account"}
</Button>
</Disabled>
</Modal.Footer>
</Form>
)}
</Formik>
</Modal.Content>
</Modal>
);
}

View File

@@ -0,0 +1,461 @@
"use client";
import { useMemo, useState } from "react";
import useSWR, { mutate } from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { toast } from "@/hooks/useToast";
import { Button, Text } from "@opal/components";
import { Content, IllustrationContent } from "@opal/layouts";
import SvgNoResult from "@opal/illustrations/no-result";
import {
SvgDownload,
SvgKey,
SvgLock,
SvgMoreHorizontal,
SvgRefreshCw,
SvgTrash,
SvgUser,
SvgUserEdit,
SvgUserKey,
SvgUserManage,
} from "@opal/icons";
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import AdminListHeader from "@/sections/admin/AdminListHeader";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import Code from "@/refresh-components/Code";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import LineItem from "@/refresh-components/buttons/LineItem";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import { markdown } from "@opal/utils";
import Message from "@/refresh-components/messages/Message";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import {
deleteApiKey,
regenerateApiKey,
updateApiKey,
} from "@/refresh-pages/admin/ServiceAccountsPage/svc";
import type { APIKey } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
import { DISCORD_SERVICE_API_KEY_NAME } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
import ApiKeyFormModal from "@/refresh-pages/admin/ServiceAccountsPage/ApiKeyFormModal";
import { Table } from "@opal/components";
import { createTableColumns } from "@opal/components/table/columns";
import { Section } from "@/layouts/general-layouts";
const API_KEY_SWR_KEY = "/api/admin/api-key";
const route = ADMIN_ROUTES.API_KEYS;
const tc = createTableColumns<APIKey>();
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
export default function ServiceAccountsPage() {
const {
data: apiKeys,
isLoading,
error,
} = useSWR<APIKey[]>(API_KEY_SWR_KEY, errorHandlingFetcher);
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [showCreateUpdateForm, setShowCreateUpdateForm] = useState(false);
const [selectedApiKey, setSelectedApiKey] = useState<APIKey | undefined>();
const [search, setSearch] = useState("");
const [regenerateTarget, setRegenerateTarget] = useState<APIKey | null>(null);
const [deleteTarget, setDeleteTarget] = useState<APIKey | null>(null);
const visibleApiKeys = (apiKeys ?? []).filter(
(key) => key.api_key_name !== DISCORD_SERVICE_API_KEY_NAME
);
const filteredApiKeys = visibleApiKeys.filter(
(key) =>
!search ||
(key.api_key_name ?? "").toLowerCase().includes(search.toLowerCase()) ||
key.api_key_display.toLowerCase().includes(search.toLowerCase())
);
const handleRoleChange = async (apiKey: APIKey, newRole: UserRole) => {
try {
const response = await updateApiKey(apiKey.api_key_id, {
name: apiKey.api_key_name ?? undefined,
role: newRole,
});
if (!response.ok) {
const errorMsg = await response.text();
toast.error(`Failed to update role: ${errorMsg}`);
return;
}
mutate(API_KEY_SWR_KEY);
toast.success("Role updated.");
} catch {
toast.error("Failed to update role.");
}
};
const handleRegenerate = async (apiKey: APIKey) => {
try {
const response = await regenerateApiKey(apiKey);
if (!response.ok) {
const errorMsg = await response.text();
toast.error(`Failed to regenerate API Key: ${errorMsg}`);
return;
}
const newKey = (await response.json()) as APIKey;
setFullApiKey(newKey.api_key);
mutate(API_KEY_SWR_KEY);
} catch (e) {
toast.error(
e instanceof Error ? e.message : "Failed to regenerate API Key."
);
}
};
const handleDelete = async (apiKey: APIKey) => {
try {
const response = await deleteApiKey(apiKey.api_key_id);
if (!response.ok) {
const errorMsg = await response.text();
toast.error(`Failed to delete API Key: ${errorMsg}`);
return;
}
mutate(API_KEY_SWR_KEY);
} catch (e) {
toast.error(e instanceof Error ? e.message : "Failed to delete API Key.");
}
};
const columns = useMemo(
() => [
tc.qualifier({
content: "icon",
getContent: () => SvgUserKey,
}),
tc.column("api_key_name", {
header: "Name",
weight: 25,
cell: (value) => (
<Content
title={value || "Unnamed"}
sizePreset="main-ui"
variant="body"
/>
),
}),
tc.column("api_key_display", {
header: "API Key",
weight: 30,
cell: (value) => (
<Text font="secondary-mono" color="text-03">
{value}
</Text>
),
}),
tc.displayColumn({
id: "account_type",
header: "Account Type",
width: { weight: 25, minWidth: 160 },
cell: (row) => (
<InputSelect
value={row.api_key_role}
onValueChange={(value) => handleRoleChange(row, value as UserRole)}
>
<InputSelect.Trigger />
<InputSelect.Content>
<InputSelect.Item
value={UserRole.ADMIN.toString()}
icon={SvgUserManage}
description="Unrestricted admin access to all endpoints."
>
{USER_ROLE_LABELS[UserRole.ADMIN]}
</InputSelect.Item>
<InputSelect.Item
value={UserRole.BASIC.toString()}
icon={SvgUser}
description="Standard user-level access to non-admin endpoints."
>
{USER_ROLE_LABELS[UserRole.BASIC]}
</InputSelect.Item>
<InputSelect.Item
value={UserRole.LIMITED.toString()}
icon={SvgLock}
description="For agents: chat posting and read-only access to other endpoints."
>
{USER_ROLE_LABELS[UserRole.LIMITED]}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
),
}),
tc.actions({
cell: (row) => (
<div className="flex flex-row gap-1">
<Button
icon={SvgRefreshCw}
prominence="tertiary"
tooltip="Regenerate"
onClick={() => setRegenerateTarget(row)}
/>
<Popover>
<Popover.Trigger asChild>
<Button
icon={SvgMoreHorizontal}
prominence="tertiary"
tooltip="More"
/>
</Popover.Trigger>
<Popover.Content side="bottom" align="end" width="md">
<PopoverMenu>
<LineItem
icon={SvgUserEdit}
onClick={() => {
setSelectedApiKey(row);
setShowCreateUpdateForm(true);
}}
>
Edit Account
</LineItem>
<LineItem
icon={SvgTrash}
danger
onClick={() => setDeleteTarget(row)}
>
Delete Account
</LineItem>
</PopoverMenu>
</Popover.Content>
</Popover>
</div>
),
}),
],
[] // eslint-disable-line react-hooks/exhaustive-deps
);
if (error) {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
title={route.title}
icon={route.icon}
description="Use service accounts to programmatically access Onyx API."
separator
/>
<SettingsLayouts.Body>
<IllustrationContent
illustration={SvgNoResult}
title="Failed to load service accounts."
description="Please check the console for more details."
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
if (isLoading) {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
title={route.title}
icon={route.icon}
description="Use service accounts to programmatically access Onyx API."
separator
/>
<SettingsLayouts.Body>
<SimpleLoader />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
const hasKeys = visibleApiKeys.length > 0;
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
title={route.title}
icon={route.icon}
description="Use service accounts to programmatically access Onyx API."
separator
/>
<SettingsLayouts.Body>
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<div className="flex flex-col">
<AdminListHeader
hasItems={hasKeys}
searchQuery={search}
onSearchQueryChange={setSearch}
placeholder="Search service accounts..."
emptyStateText="Create service account API keys with user-level access."
onAction={() => {
setSelectedApiKey(undefined);
setShowCreateUpdateForm(true);
}}
actionLabel="New Service Account"
/>
{hasKeys && (
<Table
data={filteredApiKeys}
getRowId={(row) => String(row.api_key_id)}
columns={columns}
searchTerm={search}
/>
)}
</div>
</SettingsLayouts.Body>
<Modal open={!!fullApiKey}>
<Modal.Content width="sm" height="sm">
<Modal.Header
title="Service Account API Key"
icon={SvgKey}
onClose={() => setFullApiKey(null)}
description="Save this key before continuing. It won't be shown again."
/>
<Modal.Body>
<Code showCopyButton={false}>{fullApiKey ?? ""}</Code>
</Modal.Body>
<Modal.Footer>
<BasicModalFooter
left={
<Button
prominence="secondary"
icon={SvgDownload}
onClick={() => {
if (!fullApiKey) return;
const blob = new Blob([fullApiKey], {
type: "text/plain",
});
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = "onyx-api-key.txt";
a.click();
URL.revokeObjectURL(url);
}}
>
Download
</Button>
}
submit={
// TODO(@raunakab): Create an opalified copy-button and replace it here
<Button
onClick={() => {
if (fullApiKey) {
navigator.clipboard.writeText(fullApiKey);
toast.success("API key copied to clipboard.");
}
}}
>
Copy API Key
</Button>
}
/>
</Modal.Footer>
</Modal.Content>
</Modal>
{showCreateUpdateForm && (
<ApiKeyFormModal
onCreateApiKey={(apiKey) => {
setFullApiKey(apiKey.api_key);
}}
onClose={() => {
setShowCreateUpdateForm(false);
setSelectedApiKey(undefined);
mutate(API_KEY_SWR_KEY);
}}
apiKey={selectedApiKey}
/>
)}
{regenerateTarget && (
<ConfirmationModalLayout
icon={SvgRefreshCw}
title="Regenerate API Key"
onClose={() => setRegenerateTarget(null)}
submit={
<Button
variant="danger"
onClick={async () => {
const target = regenerateTarget;
setRegenerateTarget(null);
await handleRegenerate(target);
}}
>
Regenerate Key
</Button>
}
>
<Text as="p" color="text-03">
{markdown(
`Your current API key *${
regenerateTarget.api_key_name || "Unnamed"
}* (\`${
regenerateTarget.api_key_display
}\`) will be revoked and a new key will be generated. You will need to update any applications using this key with the new one.`
)}
</Text>
</ConfirmationModalLayout>
)}
{deleteTarget && (
<ConfirmationModalLayout
icon={SvgTrash}
title="Delete Account"
onClose={() => setDeleteTarget(null)}
submit={
<Button
variant="danger"
onClick={async () => {
await handleDelete(deleteTarget);
setDeleteTarget(null);
}}
>
Delete
</Button>
}
>
<Section alignItems="start" gap={0.5}>
<Text as="p" color="text-03">
{markdown(
`Any application using the API key of account *${
deleteTarget.api_key_name || "Unnamed"
}* (\`${
deleteTarget.api_key_display
}\`) will lose access to Onyx.`
)}
</Text>
<Text as="p" color="text-03">
Deletion cannot be undone.
</Text>
</Section>
</ConfirmationModalLayout>
)}
</SettingsLayouts.Root>
);
}

View File

@@ -1,6 +1,5 @@
import { UserRole } from "@/lib/types";
// Discord bot service API key name - should match backend constant
export const DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service";
export interface APIKey {

View File

@@ -0,0 +1,38 @@
import type {
APIKeyArgs,
APIKey,
} from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
const API_KEY_URL = "/api/admin/api-key";
export async function createApiKey(args: APIKeyArgs): Promise<Response> {
return fetch(API_KEY_URL, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(args),
});
}
export async function regenerateApiKey(apiKey: APIKey): Promise<Response> {
return fetch(`${API_KEY_URL}/${apiKey.api_key_id}/regenerate`, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
}
export async function updateApiKey(
apiKeyId: number,
args: APIKeyArgs
): Promise<Response> {
return fetch(`${API_KEY_URL}/${apiKeyId}`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(args),
});
}
export async function deleteApiKey(apiKeyId: number): Promise<Response> {
return fetch(`${API_KEY_URL}/${apiKeyId}`, {
method: "DELETE",
});
}

View File

@@ -7,7 +7,7 @@ import {
IconProps,
OpenAIIcon,
} from "@/components/icons/icons";
import ProviderCard from "@/sections/cards/ProviderCard";
import ProviderCard from "@/sections/admin/ProviderCard";
import Message from "@/refresh-components/messages/Message";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { FetchError } from "@/lib/fetcher";

View File

@@ -1,69 +0,0 @@
"use client";
import React from "react";
import { cn } from "@/lib/utils";
import { Button } from "@opal/components";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { SvgPlusCircle } from "@opal/icons";
interface ActionbarProps {
hasActions: boolean;
searchQuery?: string;
onSearchQueryChange?: (query: string) => void;
onAddAction: () => void;
className?: string;
buttonText?: string;
barText?: string;
}
const Actionbar: React.FC<ActionbarProps> = ({
hasActions,
searchQuery = "",
onSearchQueryChange,
onAddAction,
className,
buttonText = "Add MCP Server",
barText = "Connect MCP server to add custom actions.",
}) => {
const handleSearchChange = (e: React.ChangeEvent<HTMLInputElement>) => {
onSearchQueryChange?.(e.target.value);
};
return (
<div
className={cn(
"flex gap-4 items-center rounded-16",
!hasActions ? "bg-background-tint-00 border border-border-01 p-4" : "",
className
)}
>
{hasActions ? (
<div className="flex-1 min-w-[160px]">
<InputTypeIn
placeholder="Search servers…"
value={searchQuery}
onChange={handleSearchChange}
leftSearchIcon
showClearButton
className="w-full !bg-transparent !border-transparent [&:is(:hover,:active,:focus,:focus-within)]:!bg-background-neutral-00 [&:is(:hover,:active,:focus,:focus-within)]:!border-border-01 [&:is(:focus,:focus-within)]:!shadow-none"
/>
</div>
) : (
<div className="flex-1">
<Text as="p" mainUiMuted text03>
{barText}
</Text>
</div>
)}
<div className="flex gap-2 items-center justify-end">
<Button icon={SvgPlusCircle} onClick={onAddAction}>
{buttonText}
</Button>
</div>
</div>
);
};
Actionbar.displayName = "Actionbar";
export default Actionbar;

View File

@@ -3,7 +3,7 @@
import { useState, useCallback, useMemo, useEffect } from "react";
import { KeyedMutator } from "swr";
import MCPActionCard from "@/sections/actions/MCPActionCard";
import Actionbar from "@/sections/actions/Actionbar";
import AdminListHeader from "@/sections/admin/AdminListHeader";
import ActionCardSkeleton from "@/sections/actions/skeleton/ActionCardSkeleton";
import { getActionIcon } from "@/lib/tools/mcpUtils";
import {
@@ -487,13 +487,13 @@ export default function MCPPageContent() {
)}
<div className="flex-shrink-0 mb-4">
<Actionbar
hasActions={isLoading || mcpServers.length > 0}
<AdminListHeader
hasItems={isLoading || mcpServers.length > 0}
searchQuery={searchQuery}
onSearchQueryChange={setSearchQuery}
onAddAction={handleAddServer}
buttonText="Add MCP Server"
barText="Connect MCP server to add custom actions."
onAction={handleAddServer}
actionLabel="Add MCP Server"
emptyStateText="Connect MCP server to add custom actions."
/>
</div>

View File

@@ -8,7 +8,7 @@ import OpenAPIAuthenticationModal, {
OpenAPIAuthFormValues,
} from "./modals/OpenAPIAuthenticationModal";
import AddOpenAPIActionModal from "./modals/AddOpenAPIActionModal";
import Actionbar from "./Actionbar";
import AdminListHeader from "@/sections/admin/AdminListHeader";
import { toast } from "@/hooks/useToast";
import OpenApiActionCard from "./OpenApiActionCard";
import { createOAuthConfig, updateOAuthConfig } from "@/lib/oauth/api";
@@ -350,13 +350,13 @@ export default function OpenApiPageContent() {
)}
<div className="flex-shrink-0 mb-4">
<Actionbar
hasActions={isOpenApiLoading || (openApiTools?.length ?? 0) > 0}
<AdminListHeader
hasItems={isOpenApiLoading || (openApiTools?.length ?? 0) > 0}
searchQuery={searchQuery}
onSearchQueryChange={setSearchQuery}
onAddAction={handleAddAction}
buttonText="Add OpenAPI Action"
barText="Add custom actions from OpenAPI schemas."
onAction={handleAddAction}
actionLabel="Add OpenAPI Action"
emptyStateText="Add custom actions from OpenAPI schemas."
/>
</div>

View File

@@ -0,0 +1,98 @@
"use client";
import { Button, Card } from "@opal/components";
import { Content } from "@opal/layouts";
import { SvgPlusCircle } from "@opal/icons";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
interface AdminListHeaderProps {
/** Whether items exist — controls search bar vs empty-state card. */
hasItems: boolean;
/** Current search query. */
searchQuery: string;
/** Called when the search query changes. */
onSearchQueryChange: (query: string) => void;
/** Search input placeholder. */
placeholder?: string;
/** Text shown in the empty-state card when no items exist. */
emptyStateText: string;
/** Called when the action button is clicked. */
onAction: () => void;
/** Label for the action button. */
actionLabel: string;
}
/**
* AdminListHeader — the top bar for simple admin list pages.
*
* Handles two states:
*
* 1. **Items exist** (`hasItems = true`): renders a search input on the left
* with a primary action button on the right.
* 2. **No items** (`hasItems = false`): renders a bordered card with
* descriptive text on the left and the same action button on the right.
*
* The action button always renders with a `SvgPlusCircle` right icon.
*
* Used on admin pages that have a flat list of items with no advanced
* filtering — e.g. Service Accounts, Groups, OpenAPI Actions, MCP Servers.
*
* @example
* ```tsx
* <AdminListHeader
* hasItems={items.length > 0}
* searchQuery={search}
* onSearchQueryChange={setSearch}
* placeholder="Search service accounts..."
* emptyStateText="Create service account API keys with user-level access."
* onAction={handleCreate}
* actionLabel="New Service Account"
* />
* ```
*/
export default function AdminListHeader({
hasItems,
searchQuery,
onSearchQueryChange,
placeholder = "Search...",
emptyStateText,
onAction,
actionLabel,
}: AdminListHeaderProps) {
const actionButton = (
<Button rightIcon={SvgPlusCircle} onClick={onAction}>
{actionLabel}
</Button>
);
if (!hasItems) {
return (
<Card paddingVariant="md" roundingVariant="lg" borderVariant="solid">
<div className="flex flex-row items-center justify-between gap-3">
<Content
title={emptyStateText}
sizePreset="main-ui"
variant="body"
prominence="muted"
widthVariant="fit"
/>
{actionButton}
</div>
</Card>
);
}
return (
<div className="flex flex-row gap-3 items-center px-2 pb-3">
<InputTypeIn
variant="internal"
leftSearchIcon
placeholder={placeholder}
value={searchQuery}
onChange={(e) => onSearchQueryChange(e.target.value)}
showClearButton={false}
/>
{actionButton}
</div>
);
}

View File

@@ -11,6 +11,40 @@ import {
SvgUnplug,
} from "@opal/icons";
/**
* ProviderCard a stateful card for selecting / connecting / disconnecting
* an external service provider (LLM, search engine, voice model, etc.).
*
* Built on opal `SelectCard` + `CardHeaderLayout`. Maps a three-state
* status model to the `SelectCard` state system:
*
* | Status | SelectCard state | Right action |
* |----------------|------------------|------------------------|
* | `disconnected` | `empty` | "Connect" button |
* | `connected` | `filled` | "Set as Default" button|
* | `selected` | `selected` | "Current Default" label|
*
* Bottom-right actions (Disconnect, Edit) are always visible when the
* provider is connected or selected.
*
* Used on admin configuration pages: Web Search, Image Generation,
* Voice, and LLM Configuration.
*
* @example
* ```tsx
* <ProviderCard
* icon={SvgGlobe}
* title="Exa"
* description="Exa.ai"
* status="connected"
* onConnect={() => openModal()}
* onSelect={() => setDefault(id)}
* onEdit={() => openEditModal()}
* onDisconnect={() => confirmDisconnect(id)}
* />
* ```
*/
type ProviderStatus = "disconnected" | "connected" | "selected";
interface ProviderCardProps {

View File

@@ -53,18 +53,19 @@ test.describe("Groups page — layout", () => {
test.beforeAll(async ({ browser }) => {
await withApiContext(browser, async (api) => {
const groups = await api.getUserGroups();
const adminGroup = groups.find((g) => g.name === "Admin" && g.is_default);
const basicGroup = groups.find((g) => g.name === "Basic" && g.is_default);
if (!adminGroup || !basicGroup) {
throw new Error("Default Admin/Basic groups not found");
}
adminGroupId = adminGroup.id;
basicGroupId = basicGroup.id;
adminGroupId = await api.createUserGroup("Admin");
basicGroupId = await api.createUserGroup("Basic");
await api.waitForGroupSync(adminGroupId);
await api.waitForGroupSync(basicGroupId);
});
});
// No afterAll — these are built-in default groups and must not be deleted
test.afterAll(async ({ browser }) => {
await withApiContext(browser, async (api) => {
await softCleanup(() => api.deleteUserGroup(adminGroupId));
await softCleanup(() => api.deleteUserGroup(basicGroupId));
});
});
test("renders page title, search, and new group button", async ({
groupsPage,
@@ -76,8 +77,7 @@ test.describe("Groups page — layout", () => {
await expect(groupsPage.newGroupButton).toBeVisible();
});
test.skip("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
// TODO: Enable once default groups are shown via include_default=true
test("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
await groupsPage.goto();
await groupsPage.expectGroupVisible("Admin");

View File

@@ -632,18 +632,6 @@ export class OnyxApiClient {
this.log(`Deleted user group: ${groupId}`);
}
/**
* Lists all user groups.
*/
async getUserGroups(): Promise<
Array<{ id: number; name: string; is_default: boolean }>
> {
const response = await this.get(
"/manage/admin/user-group?include_default=true"
);
return response.json();
}
async setUserRole(
email: string,
role: "admin" | "curator" | "global_curator" | "basic",