mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-02 05:22:43 +00:00
Compare commits
13 Commits
seed-defau
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
058f3d1403 | ||
|
|
6255a299b1 | ||
|
|
caa8811d61 | ||
|
|
7208d7ba8d | ||
|
|
23058c416d | ||
|
|
0874e0a5e6 | ||
|
|
165237faf4 | ||
|
|
19c3122fec | ||
|
|
00b228b357 | ||
|
|
59ae32f764 | ||
|
|
76000330ad | ||
|
|
eb6bd42c1e | ||
|
|
953cc28625 |
@@ -1,108 +0,0 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
@@ -1,136 +0,0 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
op.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
op.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -1,84 +0,0 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -1,116 +0,0 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Exclude inactive placeholder/anonymous users that are not real users
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.is_active == True, # noqa: E712
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Exclude inactive placeholder/anonymous users that are not real users
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.is_active == True, # noqa: E712
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -259,7 +255,6 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -274,7 +269,6 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -282,8 +276,6 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -510,9 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
for uid in user_group.user_ids:
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -820,9 +796,6 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
for uid in set(added_user_ids) | set(removed_user_ids):
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -862,17 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids = (
|
||||
db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -901,11 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
for uid in affected_user_ids:
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -52,13 +52,11 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
@@ -488,7 +486,6 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -509,25 +506,13 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
|
||||
@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -60,50 +56,27 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -215,9 +185,6 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,8 +5,6 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -53,16 +50,12 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
@@ -120,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -696,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -746,23 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
user = await self.user_db.get(user.id) # type: ignore[arg-type]
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -848,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -1576,7 +1554,6 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
@@ -51,6 +52,60 @@ logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
class FileContextResult(BaseModel):
|
||||
"""Result of building a file's LLM context representation."""
|
||||
|
||||
message: ChatMessageSimple
|
||||
tool_metadata: FileToolMetadata
|
||||
|
||||
|
||||
def build_file_context(
|
||||
tool_file_id: str,
|
||||
filename: str,
|
||||
file_type: ChatFileType,
|
||||
content_text: str | None = None,
|
||||
token_count: int = 0,
|
||||
approx_char_count: int | None = None,
|
||||
) -> FileContextResult:
|
||||
"""Build the LLM context representation for a single file.
|
||||
|
||||
Centralises how files should appear in the LLM prompt
|
||||
— the ID that FileReaderTool accepts (``UserFile.id`` for user files).
|
||||
"""
|
||||
if file_type.use_metadata_only():
|
||||
message_text = (
|
||||
f"File: {filename} (id={tool_file_id})\n"
|
||||
"Use the file_reader or python tools to access "
|
||||
"this file's contents."
|
||||
)
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=max(1, len(message_text) // 4),
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
else:
|
||||
message_text = f"File: {filename}\n{content_text or ''}\nEnd of File"
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
|
||||
metadata = FileToolMetadata(
|
||||
file_id=tool_file_id,
|
||||
filename=filename,
|
||||
approx_char_count=(
|
||||
approx_char_count
|
||||
if approx_char_count is not None
|
||||
else len(content_text or "")
|
||||
),
|
||||
)
|
||||
|
||||
return FileContextResult(message=message, tool_metadata=metadata)
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
chat_session_request: ChatSessionCreationRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -538,7 +593,7 @@ def convert_chat_history(
|
||||
for idx, chat_message in enumerate(chat_history):
|
||||
if chat_message.message_type == MessageType.USER:
|
||||
# Process files attached to this message
|
||||
text_files: list[ChatLoadedFile] = []
|
||||
text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
|
||||
if chat_message.files:
|
||||
@@ -549,34 +604,26 @@ def convert_chat_history(
|
||||
if loaded_file.file_type == ChatFileType.IMAGE:
|
||||
image_files.append(loaded_file)
|
||||
else:
|
||||
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
|
||||
text_files.append(loaded_file)
|
||||
# Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages
|
||||
text_files.append((loaded_file, file_descriptor))
|
||||
|
||||
# Add text files as separate messages before the user message.
|
||||
# Each message is tagged with ``file_id`` so that forgotten files
|
||||
# can be detected after context-window truncation.
|
||||
for text_file in text_files:
|
||||
file_text = text_file.content_text or ""
|
||||
filename = text_file.filename
|
||||
message = (
|
||||
f"File: {filename}\n{file_text}\nEnd of File"
|
||||
if filename
|
||||
else file_text
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=message,
|
||||
token_count=text_file.token_count,
|
||||
message_type=MessageType.USER,
|
||||
image_files=None,
|
||||
file_id=text_file.file_id,
|
||||
)
|
||||
)
|
||||
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
|
||||
file_id=text_file.file_id,
|
||||
filename=filename or "unknown",
|
||||
approx_char_count=len(file_text),
|
||||
for text_file, fd in text_files:
|
||||
# Use user_file_id as the FileReaderTool accepts that.
|
||||
# Fall back to the file-store path id.
|
||||
tool_id = fd.get("user_file_id") or text_file.file_id
|
||||
filename = text_file.filename or "unknown"
|
||||
ctx = build_file_context(
|
||||
tool_file_id=tool_id,
|
||||
filename=filename,
|
||||
file_type=text_file.file_type,
|
||||
content_text=text_file.content_text,
|
||||
token_count=text_file.token_count,
|
||||
)
|
||||
simple_messages.append(ctx.message)
|
||||
all_injected_file_metadata[tool_id] = ctx.tool_metadata
|
||||
|
||||
# Sum token counts from image files (excluding project image files)
|
||||
image_token_count = (
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -278,7 +278,6 @@ class NotificationType(str, Enum):
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
USER_GROUP_ASSIGNMENT_FAILED = "user_group_assignment_failed"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
|
||||
@@ -11,19 +11,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -92,7 +87,6 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -105,21 +99,7 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
# Late import to avoid circular dependency (api_key <- users <- api_key).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = list(result.scalars().all())
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
@@ -631,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -853,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -13,19 +13,19 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
|
||||
@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
@@ -91,18 +90,9 @@ def get_notifications(
|
||||
notif_type: NotificationType | None = None,
|
||||
include_dismissed: bool = True,
|
||||
) -> list[Notification]:
|
||||
if user is None:
|
||||
user_filter = Notification.user_id.is_(None)
|
||||
elif user.role == UserRole.ADMIN:
|
||||
# Admins see their own notifications AND admin-targeted ones (user_id IS NULL)
|
||||
user_filter = or_(
|
||||
Notification.user_id == user.id,
|
||||
Notification.user_id.is_(None),
|
||||
)
|
||||
else:
|
||||
user_filter = Notification.user_id == user.id
|
||||
|
||||
query = select(Notification).where(user_filter)
|
||||
query = select(Notification).where(
|
||||
Notification.user_id == user.id if user else Notification.user_id.is_(None)
|
||||
)
|
||||
if not include_dismissed:
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(user_id: UUID, db_session: Session) -> None:
|
||||
"""Recompute a single user's granted permissions from their group grants.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
stmt = (
|
||||
select(PermissionGrant.permission)
|
||||
.join(
|
||||
User__UserGroup,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id == user_id,
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
rows = db_session.execute(stmt).scalars().all()
|
||||
# sorted for consistent ordering in DB — easier to read when debugging
|
||||
granted = sorted({p.value for p in rows})
|
||||
|
||||
db_session.execute(
|
||||
update(User).where(User.id == user_id).values(effective_permissions=granted)
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = list(
|
||||
db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(user_ids),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID, set[str]] = defaultdict(set)
|
||||
for uid in user_ids:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid)
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
@@ -19,7 +19,6 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -28,11 +27,8 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
@@ -302,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -313,7 +308,6 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -350,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -382,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -503,14 +421,13 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
stmt = (
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -518,11 +435,7 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -23,6 +23,11 @@ class ChatFileType(str, Enum):
|
||||
ChatFileType.TABULAR,
|
||||
)
|
||||
|
||||
def use_metadata_only(self) -> bool:
|
||||
"""File types where we can ignore the file content
|
||||
and only use the metadata."""
|
||||
return self in (ChatFileType.TABULAR,)
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
|
||||
|
||||
@@ -110,16 +110,20 @@ def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
|
||||
# check for plain text normalized version first, then use original file otherwise
|
||||
try:
|
||||
file_io = file_store.read_file(plaintext_file_name, mode="b")
|
||||
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
|
||||
plaintext_chat_file_type = (
|
||||
ChatFileType.PLAIN_TEXT
|
||||
if chat_file_type != ChatFileType.IMAGE
|
||||
else chat_file_type
|
||||
)
|
||||
|
||||
# if we have plaintext for image (which happens when image extraction is enabled), we use PLAIN_TEXT type
|
||||
if file_io is not None:
|
||||
# Metadata-only file types preserve their original type so
|
||||
# downstream injection paths can route them correctly.
|
||||
if chat_file_type.use_metadata_only():
|
||||
plaintext_chat_file_type = chat_file_type
|
||||
elif file_io is not None:
|
||||
# if we have plaintext for image (which happens when image
|
||||
# extraction is enabled), we use PLAIN_TEXT type
|
||||
plaintext_chat_file_type = ChatFileType.PLAIN_TEXT
|
||||
else:
|
||||
plaintext_chat_file_type = (
|
||||
ChatFileType.PLAIN_TEXT
|
||||
if chat_file_type != ChatFileType.IMAGE
|
||||
else chat_file_type
|
||||
)
|
||||
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -774,13 +773,6 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
|
||||
@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -74,8 +68,7 @@ def create_test_user(
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -126,15 +126,6 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,30 +104,13 @@ class UserGroupManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser,
|
||||
include_default: bool = False,
|
||||
) -> list[UserGroup]:
|
||||
params: dict[str, str] = {}
|
||||
if include_default:
|
||||
params["include_default"] = "true"
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -37,120 +33,3 @@ def test_limited(reset: None) -> None: # noqa: ARG001
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def _get_service_account_account_type(
|
||||
admin_user: DATestUser,
|
||||
api_key_user_id: UUID,
|
||||
) -> AccountType:
|
||||
"""Fetch the account_type of a service account user via the user listing API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/users",
|
||||
headers=admin_user.headers,
|
||||
params={"include_api_keys": "true"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
user_id_str = str(api_key_user_id)
|
||||
for user in data["accepted"]:
|
||||
if user["id"] == user_id_str:
|
||||
return AccountType(user["account_type"])
|
||||
raise AssertionError(
|
||||
f"Service account user {user_id_str} not found in user listing"
|
||||
)
|
||||
|
||||
|
||||
def _get_default_group_user_ids(
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
admin_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_ids = {str(u.id) for u in basic_group.users}
|
||||
return admin_ids, basic_ids
|
||||
|
||||
|
||||
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify no group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "LIMITED API key should NOT be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "LIMITED API key should NOT be in Basic default group"
|
||||
|
||||
|
||||
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.BASIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Basic group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "BASIC API key should NOT be in Admin default group"
|
||||
|
||||
|
||||
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.ADMIN,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Admin group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "ADMIN API key should NOT be in Basic default group"
|
||||
|
||||
@@ -4,10 +4,8 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -97,63 +95,3 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD when converting a
|
||||
non-web user (EXT_PERM_USER) and that the user receives the correct role
|
||||
(BASIC) after conversion.
|
||||
|
||||
This validates the permissions-migration-phase2 changes which ensure that:
|
||||
1. account_type is updated to 'standard' on SAML conversion
|
||||
2. The converted user is assigned to the Basic default group
|
||||
"""
|
||||
# Create an admin user (first user is automatically admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# Create a user and set them as EXT_PERM_USER
|
||||
test_email = "ext_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = response.json()
|
||||
|
||||
# Verify account_type is set to standard after conversion
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
|
||||
|
||||
# Verify role is BASIC after conversion
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user was assigned to the Basic default group
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
|
||||
assert basic_default, "Basic default group not found"
|
||||
|
||||
basic_group = basic_default[0]
|
||||
member_emails = {u.email for u in basic_group.users}
|
||||
assert test_email in member_emails, (
|
||||
f"Converted user '{test_email}' not found in Basic default group members: "
|
||||
f"{member_emails}"
|
||||
)
|
||||
|
||||
@@ -35,16 +35,9 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
@@ -218,49 +211,6 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_create_user_default_group_and_account_type(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
|
||||
email = f"scim_defaults_{idp_style}@example.com"
|
||||
ext_id = f"ext-defaults-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
user_id = resp.json()["id"]
|
||||
|
||||
# --- Verify group assignment via SCIM GET ---
|
||||
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
|
||||
assert get_resp.status_code == 200
|
||||
groups = get_resp.json().get("groups", [])
|
||||
group_names = {g["display"] for g in groups}
|
||||
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
|
||||
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
|
||||
|
||||
# --- Verify account_type via admin API ---
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
page = UserManager.get_user_page(
|
||||
user_performing_action=admin,
|
||||
search_query=email,
|
||||
)
|
||||
assert page.total_items >= 1
|
||||
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
|
||||
assert (
|
||||
scim_user_snapshot is not None
|
||||
), f"SCIM user {email} not found in user listing"
|
||||
assert (
|
||||
scim_user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_gets_permissions_when_added_to_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
|
||||
|
||||
# basic_user starts with only "basic" from the default group
|
||||
initial_permissions = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in initial_permissions
|
||||
assert "add:agents" not in initial_permissions
|
||||
|
||||
# Create a new group and add basic_user
|
||||
group = UserGroupManager.create(
|
||||
name="perm-test-group",
|
||||
user_ids=[admin_user.id, basic_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Grant a non-basic permission to the group and recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_group = db_session.get(UserGroupModel, group.id)
|
||||
assert db_group is not None
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
recompute_user_permissions__no_commit(basic_user.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the user gained the new permission (expanded includes read:agents)
|
||||
updated_permissions = UserManager.get_permissions(basic_user)
|
||||
assert (
|
||||
"add:agents" in updated_permissions
|
||||
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
|
||||
assert (
|
||||
"read:agents" in updated_permissions
|
||||
), f"User should have implied 'read:agents', got: {updated_permissions}"
|
||||
assert "basic" in updated_permissions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_group_permission_change_propagates_to_all_members(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_propagate")
|
||||
user_a: DATestUser = UserManager.create(name="user_a_propagate")
|
||||
user_b: DATestUser = UserManager.create(name="user_b_propagate")
|
||||
|
||||
group = UserGroupManager.create(
|
||||
name="propagate-test-group",
|
||||
user_ids=[admin_user.id, user_a.id, user_b.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Neither user should have add:agents yet
|
||||
for u in (user_a, user_b):
|
||||
assert "add:agents" not in UserManager.get_permissions(u)
|
||||
|
||||
# Grant add:agents to the group, then batch-recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
grant = PermissionGrant(
|
||||
group_id=group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
db_session.add(grant)
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Both users should now have the permission (plus implied read:agents)
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
|
||||
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
|
||||
|
||||
# Soft-delete the grant and recompute — permission should be removed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_grant = (
|
||||
db_session.query(PermissionGrant)
|
||||
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
|
||||
.first()
|
||||
)
|
||||
assert db_grant is not None
|
||||
db_grant.is_deleted = True
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
|
||||
@@ -1,30 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
|
||||
|
||||
user_group = UserGroupManager.create(
|
||||
name="basic-perm-test-group",
|
||||
user_ids=[admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
permissions = UserGroupManager.get_permissions(
|
||||
user_group=user_group,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
"basic" in permissions
|
||||
), f"New group should have 'basic' permission, got: {permissions}"
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Integration tests for default group assignment on user registration.
|
||||
|
||||
Verifies that:
|
||||
- The first registered user is assigned to the Admin default group
|
||||
- Subsequent registered users are assigned to the Basic default group
|
||||
- account_type is set to STANDARD for email/password registrations
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
|
||||
# Register first user — should become admin
|
||||
admin_user: DATestUser = UserManager.create(name="first_user")
|
||||
assert admin_user.role == UserRole.ADMIN
|
||||
|
||||
# Register second user — should become basic
|
||||
basic_user: DATestUser = UserManager.create(name="second_user")
|
||||
assert basic_user.role == UserRole.BASIC
|
||||
|
||||
# Fetch all groups including default ones
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
|
||||
# Find the default Admin and Basic groups
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
# Verify admin user is in Admin group and NOT in Basic group
|
||||
admin_group_user_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_group_user_ids = {str(u.id) for u in basic_group.users}
|
||||
|
||||
assert (
|
||||
admin_user.id in admin_group_user_ids
|
||||
), "First user should be in Admin default group"
|
||||
assert (
|
||||
admin_user.id not in basic_group_user_ids
|
||||
), "First user should NOT be in Basic default group"
|
||||
|
||||
# Verify basic user is in Basic group and NOT in Admin group
|
||||
assert (
|
||||
basic_user.id in basic_group_user_ids
|
||||
), "Second user should be in Basic default group"
|
||||
assert (
|
||||
basic_user.id not in admin_group_user_ids
|
||||
), "Second user should NOT be in Admin default group"
|
||||
|
||||
# Verify account_type is STANDARD for both users via user listing API
|
||||
paginated_result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
users_by_id = {str(u.id): u for u in paginated_result.items}
|
||||
|
||||
admin_snapshot = users_by_id.get(admin_user.id)
|
||||
basic_snapshot = users_by_id.get(basic_user.id)
|
||||
assert admin_snapshot is not None, "Admin user not found in user listing"
|
||||
assert basic_snapshot is not None, "Basic user not found in user listing"
|
||||
|
||||
assert (
|
||||
admin_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
|
||||
assert (
|
||||
basic_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.permissions import ALL_PERMISSIONS
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.permissions import resolve_effective_permissions
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_effective_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveEffectivePermissions:
|
||||
def test_empty_set(self) -> None:
|
||||
assert resolve_effective_permissions(set()) == set()
|
||||
|
||||
def test_basic_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"basic"})
|
||||
assert result == {"basic"}
|
||||
|
||||
def test_single_implication(self) -> None:
|
||||
result = resolve_effective_permissions({"add:agents"})
|
||||
assert result == {"add:agents", "read:agents"}
|
||||
|
||||
def test_manage_agents_implies_add_and_read(self) -> None:
|
||||
"""manage:agents directly maps to {add:agents, read:agents}."""
|
||||
result = resolve_effective_permissions({"manage:agents"})
|
||||
assert result == {"manage:agents", "add:agents", "read:agents"}
|
||||
|
||||
def test_manage_connectors_chain(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:connectors"})
|
||||
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
|
||||
|
||||
def test_manage_document_sets(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:document_sets"})
|
||||
assert result == {
|
||||
"manage:document_sets",
|
||||
"read:document_sets",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_manage_user_groups_implies_all_reads(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:user_groups"})
|
||||
assert result == {
|
||||
"manage:user_groups",
|
||||
"read:connectors",
|
||||
"read:document_sets",
|
||||
"read:agents",
|
||||
"read:users",
|
||||
}
|
||||
|
||||
def test_admin_override(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_admin_with_others(self) -> None:
|
||||
result = resolve_effective_permissions({"admin", "basic"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_multi_group_union(self) -> None:
|
||||
result = resolve_effective_permissions(
|
||||
{"add:agents", "manage:connectors", "basic"}
|
||||
)
|
||||
assert result == {
|
||||
"basic",
|
||||
"add:agents",
|
||||
"read:agents",
|
||||
"manage:connectors",
|
||||
"add:connectors",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_toggle_permission_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"read:agent_analytics"})
|
||||
assert result == {"read:agent_analytics"}
|
||||
|
||||
def test_all_permissions_for_admin(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert len(result) == len(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_permissions (expands implied at read time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEffectivePermissions:
|
||||
def test_expands_implied_permissions(self) -> None:
|
||||
"""Column stores only granted; get_effective_permissions expands implied."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["add:agents"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
|
||||
|
||||
def test_admin_expands_to_all(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set(Permission)
|
||||
|
||||
def test_basic_stays_basic(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.BASIC_ACCESS}
|
||||
|
||||
def test_empty_column(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_permission (FastAPI dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass(self) -> None:
|
||||
"""Admin stored in column should pass any permission check."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_required_permission(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implied_permission_passes(self) -> None:
|
||||
"""manage:connectors implies read:connectors at read time."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.READ_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_permission_raises(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await dep(user=user)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_permissions_fails(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
|
||||
dep = require_permission(Permission.BASIC_ACCESS)
|
||||
with pytest.raises(OnyxError):
|
||||
await dep(user=user)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
Unit tests for UserCreate schema dict methods.
|
||||
|
||||
Verifies that account_type is always included in create_update_dict
|
||||
and create_update_dict_superuser.
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
def test_create_update_dict_includes_default_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_includes_explicit_account_type() -> None:
|
||||
uc = UserCreate(
|
||||
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
|
||||
)
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_superuser_includes_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict_superuser()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
@@ -300,6 +300,66 @@ class TestExtractContextFiles:
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_tool_metadata_file_id_matches_chat_history_file_id(
|
||||
self, mock_load: MagicMock
|
||||
) -> None:
|
||||
"""The file_id in tool metadata (from extract_context_files) and the
|
||||
file_id in chat history messages (from build_file_context) must
|
||||
agree, otherwise the LLM sees different IDs for the same file across
|
||||
turns.
|
||||
|
||||
In production, UserFile.id (UUID PK) differs from UserFile.file_id
|
||||
(file-store path). Both pathways should produce the same file_id
|
||||
(UserFile.id) for FileReaderTool."""
|
||||
from onyx.chat.chat_utils import build_file_context
|
||||
|
||||
user_file_uuid = uuid4()
|
||||
file_store_path = f"user_files/{user_file_uuid}/data.csv"
|
||||
|
||||
uf = UserFile(
|
||||
id=user_file_uuid,
|
||||
file_id=file_store_path,
|
||||
name="data.csv",
|
||||
token_count=100,
|
||||
file_type="text/csv",
|
||||
)
|
||||
|
||||
in_memory = InMemoryChatFile(
|
||||
file_id=file_store_path,
|
||||
content=b"col1,col2\na,b",
|
||||
file_type=ChatFileType.TABULAR,
|
||||
filename="data.csv",
|
||||
)
|
||||
|
||||
mock_load.return_value = [in_memory]
|
||||
|
||||
# Pathway 1: extract_context_files (project/persona context)
|
||||
result = extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
tool_metadata_file_id = result.file_metadata_for_tool[0].file_id
|
||||
|
||||
# Pathway 2: build_file_context (chat history path)
|
||||
# In convert_chat_history, tool_file_id comes from
|
||||
# file_descriptor["user_file_id"], which is str(UserFile.id)
|
||||
ctx = build_file_context(
|
||||
tool_file_id=str(user_file_uuid),
|
||||
filename="data.csv",
|
||||
file_type=ChatFileType.TABULAR,
|
||||
)
|
||||
chat_history_file_id = ctx.tool_metadata.file_id
|
||||
|
||||
# Both pathways must produce the same ID for the LLM
|
||||
assert tool_metadata_file_id == chat_history_file_id, (
|
||||
f"File ID mismatch: extract_context_files uses '{tool_metadata_file_id}' "
|
||||
f"but build_file_context uses '{chat_history_file_id}'."
|
||||
)
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
@@ -316,6 +376,128 @@ class TestExtractContextFiles:
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_metadata_only_files_not_counted_in_aggregate_tokens(
|
||||
self, mock_load: MagicMock
|
||||
) -> None:
|
||||
"""Metadata-only files (TABULAR) should not count toward the token budget."""
|
||||
text_file_id = str(uuid4())
|
||||
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
|
||||
# TABULAR file with large token count — should be excluded from aggregate
|
||||
tabular_uf = _make_user_file(
|
||||
token_count=50000, name="huge.xlsx", file_id=str(uuid4())
|
||||
)
|
||||
tabular_uf.file_type = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=text_file_id, content="text content"),
|
||||
InMemoryChatFile(
|
||||
file_id=str(tabular_uf.id),
|
||||
content=b"binary xlsx",
|
||||
file_type=ChatFileType.TABULAR,
|
||||
filename="huge.xlsx",
|
||||
),
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[text_uf, tabular_uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
# Text file fits (100 < 6000), so files should be loaded
|
||||
assert result.file_texts == ["text content"]
|
||||
# TABULAR file should appear as tool metadata, not in file_texts
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "huge.xlsx"
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_metadata_only_files_loaded_as_tool_metadata(
|
||||
self, mock_load: MagicMock
|
||||
) -> None:
|
||||
"""When files fit, metadata-only files appear in file_metadata_for_tool."""
|
||||
text_file_id = str(uuid4())
|
||||
tabular_file_id = str(uuid4())
|
||||
text_uf = _make_user_file(token_count=100, file_id=text_file_id)
|
||||
tabular_uf = _make_user_file(
|
||||
token_count=500, name="data.csv", file_id=tabular_file_id
|
||||
)
|
||||
tabular_uf.file_type = "text/csv"
|
||||
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=text_file_id, content="hello"),
|
||||
InMemoryChatFile(
|
||||
file_id=tabular_file_id,
|
||||
content=b"col1,col2\na,b",
|
||||
file_type=ChatFileType.TABULAR,
|
||||
filename="data.csv",
|
||||
),
|
||||
]
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[text_uf, tabular_uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["hello"]
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "data.csv"
|
||||
# TABULAR should not appear in file_metadata (that's for citation)
|
||||
assert all(m.filename != "data.csv" for m in result.file_metadata)
|
||||
|
||||
def test_overflow_with_vector_db_preserves_metadata_only_tool_metadata(
|
||||
self,
|
||||
) -> None:
|
||||
"""When text files overflow with vector DB enabled, metadata-only files
|
||||
should still be exposed via file_metadata_for_tool since they aren't
|
||||
in the vector DB and would otherwise be inaccessible."""
|
||||
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
|
||||
tabular_uf.file_type = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[text_uf, tabular_uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
# Text files overflow → search filter enabled
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
# TABULAR file should still be in tool metadata
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "data.xlsx"
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_no_vector_db_includes_all_files_in_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled and files overflow, all files
|
||||
(both text and metadata-only) appear in file_metadata_for_tool."""
|
||||
text_uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
tabular_uf = _make_user_file(token_count=500, name="data.xlsx")
|
||||
tabular_uf.file_type = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
|
||||
result = extract_context_files(
|
||||
user_files=[text_uf, tabular_uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 2
|
||||
filenames = {m.filename for m in result.file_metadata_for_tool}
|
||||
assert filenames == {"bigfile.txt", "data.xlsx"}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
@@ -644,6 +644,92 @@ class TestConstructMessageHistory:
|
||||
assert "Project file 0 content" in project_message.message
|
||||
assert "Project file 1 content" in project_message.message
|
||||
|
||||
def test_file_metadata_for_tool_produces_message(self) -> None:
|
||||
"""When context_files has file_metadata_for_tool, a metadata listing
|
||||
message should be injected into the history."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Analyze the spreadsheet", MessageType.USER, 5)
|
||||
|
||||
context_files = ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=0,
|
||||
file_metadata_for_tool=[
|
||||
FileToolMetadata(
|
||||
file_id="xlsx-1",
|
||||
filename="report.xlsx",
|
||||
approx_char_count=100000,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=[user_msg],
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
available_tokens=1000,
|
||||
token_counter=_simple_token_counter,
|
||||
)
|
||||
|
||||
# Should have: system, tool_metadata_message, user
|
||||
assert len(result) == 3
|
||||
metadata_msg = result[1]
|
||||
assert metadata_msg.message_type == MessageType.USER
|
||||
assert "report.xlsx" in metadata_msg.message
|
||||
assert "xlsx-1" in metadata_msg.message
|
||||
|
||||
def test_metadata_only_and_text_files_both_present(self) -> None:
|
||||
"""When both text content and tool metadata are present, both messages
|
||||
should appear in the history."""
|
||||
system_prompt = create_message("System", MessageType.SYSTEM, 10)
|
||||
user_msg = create_message("Summarize everything", MessageType.USER, 5)
|
||||
|
||||
context_files = ExtractedContextFiles(
|
||||
file_texts=["Text file content here"],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=100,
|
||||
file_metadata=[
|
||||
ContextFileMetadata(
|
||||
file_id="txt-1",
|
||||
filename="notes.txt",
|
||||
file_content="Text file content here",
|
||||
),
|
||||
],
|
||||
uncapped_token_count=100,
|
||||
file_metadata_for_tool=[
|
||||
FileToolMetadata(
|
||||
file_id="xlsx-1",
|
||||
filename="data.xlsx",
|
||||
approx_char_count=50000,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=[user_msg],
|
||||
reminder_message=None,
|
||||
context_files=context_files,
|
||||
available_tokens=2000,
|
||||
token_counter=_simple_token_counter,
|
||||
)
|
||||
|
||||
# Should have: system, context_files_message, tool_metadata_message, user
|
||||
assert len(result) == 4
|
||||
# Context files message (text content)
|
||||
assert "documents" in result[1].message
|
||||
assert "Text file content here" in result[1].message
|
||||
# Tool metadata message
|
||||
assert "data.xlsx" in result[2].message
|
||||
assert result[3] == user_msg
|
||||
|
||||
|
||||
def _simple_token_counter(text: str) -> int:
|
||||
"""Approximate token counter for tests (~4 chars per token)."""
|
||||
|
||||
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 failed")
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_stop_button_calls_completion_for_all_models(self) -> None:
|
||||
"""llm_loop_completion_handle must be called for all models when the stop button fires.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
|
||||
server thread is never blocked. Worker threads detect drain_done.is_set()
|
||||
after run_llm_loop completes and self-persist the result via
|
||||
llm_loop_completion_handle using their own DB session.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
|
||||
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
|
||||
# so the self-completion path (drain_done.is_set() check) is always taken.
|
||||
disconnect_received = threading.Event()
|
||||
# Set by the llm_loop_completion_handle mock when called.
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until the main thread signals that gen.close() has been called. This
|
||||
ensures drain_done is set before we return so model_succeeded is checked
|
||||
against a set drain_done — no race condition.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
disconnect_received.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
# Unblock the worker now that drain_done has been set by gen.close().
|
||||
disconnect_received.set()
|
||||
|
||||
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
|
||||
# Wait here, inside the patch context, so that get_session_with_current_tenant
|
||||
# and llm_loop_completion_handle mocks are still active when the worker calls them.
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "worker must self-complete via drain_done within 5 seconds"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called once for the successful model"
|
||||
|
||||
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
|
||||
"""B1 regression: model finishes BEFORE GeneratorExit fires.
|
||||
|
||||
The worker exits _run_model with drain_done.is_set()=False and skips
|
||||
self-completion. When gen.close() fires afterward, the finally else-branch
|
||||
must detect model_succeeded=True and call llm_loop_completion_handle itself.
|
||||
|
||||
Contrast with test_http_disconnect_completion_via_generator_exit, which
|
||||
tests the opposite ordering (worker finishes AFTER disconnect).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_and_return_immediately(**kwargs: Any) -> None:
|
||||
# Emit one packet so the drain loop has something to yield, then return
|
||||
# immediately — no blocking. The worker will be done in microseconds.
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_and_return_immediately,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
|
||||
# Give the worker thread time to finish completely (emit + return +
|
||||
# finally + self-completion check). It does almost no work, so 100 ms
|
||||
# is far more than enough while still keeping the test fast.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now close — worker is already done, so else-branch handles completion.
|
||||
gen.close()
|
||||
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "disconnect handler must call completion for a model that already finished"
|
||||
assert mock_handle.call_count == 1, "completion must be called exactly once"
|
||||
|
||||
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
|
||||
"""B2 regression: stop-button must NOT call completion for an errored model.
|
||||
|
||||
When model 0 raises an exception, its reserved ChatMessage must not be
|
||||
saved with 'stopped by user' — that message is wrong for a model that
|
||||
errored. llm_loop_completion_handle must only be called for non-errored
|
||||
models when the stop button fires.
|
||||
"""
|
||||
|
||||
def fail_model_0(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 errored")
|
||||
# Model 1: run forever (stop button fires before it finishes)
|
||||
time.sleep(10)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
# Return False immediately so the stop-button path fires while model 1
|
||||
# is still sleeping (model 0 has already errored by then).
|
||||
setup.check_is_connected = lambda: False
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Completion must NOT be called for model 0 (it errored).
|
||||
# It MAY be called for model 1 (still in-flight when stop fired).
|
||||
for call in mock_handle.call_args_list:
|
||||
assert (
|
||||
call.kwargs.get("llm") is not setup.llms[0]
|
||||
), "llm_loop_completion_handle must not be called for the errored model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
|
||||
|
||||
Covers:
|
||||
1. Standard/service-account users get assigned to the correct default group
|
||||
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
|
||||
3. Missing default group raises RuntimeError
|
||||
4. Already-in-group is a no-op
|
||||
5. IntegrityError race condition is handled gracefully
|
||||
6. The function never commits the session
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
|
||||
def _mock_user(
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
email: str = "test@example.com",
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = email
|
||||
user.account_type = account_type
|
||||
return user
|
||||
|
||||
|
||||
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
group.is_default = True
|
||||
return group
|
||||
|
||||
|
||||
def _make_query_chain(first_return: object = None) -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.first.return_value = first_return
|
||||
return chain
|
||||
|
||||
|
||||
def _setup_db_session(
|
||||
group_result: object = None,
|
||||
membership_result: object = None,
|
||||
) -> MagicMock:
|
||||
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
|
||||
db_session = MagicMock()
|
||||
|
||||
group_chain = _make_query_chain(group_result)
|
||||
membership_chain = _make_query_chain(membership_result)
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model is UserGroup:
|
||||
return group_chain
|
||||
if model is User__UserGroup:
|
||||
return membership_chain
|
||||
return MagicMock()
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
return db_session
|
||||
|
||||
|
||||
def test_standard_user_assigned_to_basic_group() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_id == user.id
|
||||
assert added.user_group_id == group.id
|
||||
db_session.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_user_assigned_to_admin_group() -> None:
|
||||
group = _mock_group("Admin", group_id=2)
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_group_id == group.id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"account_type",
|
||||
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
|
||||
)
|
||||
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
|
||||
db_session = MagicMock()
|
||||
user = _mock_user(account_type)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.query.assert_not_called()
|
||||
db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_service_account_not_skipped() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.SERVICE_ACCOUNT)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
|
||||
|
||||
def test_missing_default_group_raises_error() -> None:
|
||||
db_session = _setup_db_session(group_result=None)
|
||||
user = _mock_user()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Default group .* not found"):
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
|
||||
def test_already_in_group_is_noop() -> None:
|
||||
group = _mock_group("Basic")
|
||||
existing_membership = MagicMock()
|
||||
db_session = _setup_db_session(
|
||||
group_result=group, membership_result=existing_membership
|
||||
)
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
db_session.begin_nested.assert_not_called()
|
||||
|
||||
|
||||
def test_integrity_error_race_condition_handled() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
|
||||
user = _mock_user()
|
||||
|
||||
# Should not raise
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
savepoint.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_no_commit_called_on_successful_assignment() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.commit.assert_not_called()
|
||||
@@ -3,7 +3,6 @@ from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
@@ -26,7 +25,6 @@ def _mock_user(
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.account_type = AccountType.STANDARD
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement == placement
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { createApiKey, updateApiKey } from "./lib";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { FormikField } from "@/refresh-components/form/FormikField";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
|
||||
import { APIKey } from "./types";
|
||||
import { SvgKey } from "@opal/icons";
|
||||
|
||||
export interface OnyxApiKeyFormProps {
|
||||
onClose: () => void;
|
||||
onCreateApiKey: (apiKey: APIKey) => void;
|
||||
apiKey?: APIKey;
|
||||
}
|
||||
|
||||
export default function OnyxApiKeyForm({
|
||||
onClose,
|
||||
onCreateApiKey,
|
||||
apiKey,
|
||||
}: OnyxApiKeyFormProps) {
|
||||
const isUpdate = apiKey !== undefined;
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title={isUpdate ? "Update API Key" : "Create a new API Key"}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={{
|
||||
name: apiKey?.api_key_name || "",
|
||||
role: apiKey?.api_key_role || UserRole.BASIC.toString(),
|
||||
}}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
// Prepare the payload with the UserRole
|
||||
const payload = {
|
||||
...values,
|
||||
role: values.role as UserRole, // Assign the role directly as a UserRole type
|
||||
};
|
||||
|
||||
let response;
|
||||
if (isUpdate) {
|
||||
response = await updateApiKey(apiKey.api_key_id, payload);
|
||||
} else {
|
||||
response = await createApiKey(payload);
|
||||
}
|
||||
formikHelpers.setSubmitting(false);
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
isUpdate
|
||||
? "Successfully updated API key!"
|
||||
: "Successfully created API key!"
|
||||
);
|
||||
if (!isUpdate) {
|
||||
onCreateApiKey(await response.json());
|
||||
}
|
||||
onClose();
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
toast.error(
|
||||
isUpdate
|
||||
? `Error updating API key - ${errorMsg}`
|
||||
: `Error creating API key - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form className="w-full overflow-visible">
|
||||
<Modal.Body>
|
||||
<Text as="p">
|
||||
Choose a memorable name for your API key. This is optional and
|
||||
can be added or changed later!
|
||||
</Text>
|
||||
|
||||
<FormikField<string>
|
||||
name="name"
|
||||
render={(field, helper, _meta, state) => (
|
||||
<FormField name="name" state={state} className="w-full">
|
||||
<FormField.Label>Name (optional):</FormField.Label>
|
||||
<FormField.Control>
|
||||
<InputTypeIn
|
||||
{...field}
|
||||
placeholder=""
|
||||
onClear={() => helper.setValue("")}
|
||||
showClearButton={false}
|
||||
/>
|
||||
</FormField.Control>
|
||||
</FormField>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormikField<string>
|
||||
name="role"
|
||||
render={(field, helper, _meta, state) => (
|
||||
<FormField name="role" state={state} className="w-full">
|
||||
<FormField.Label>Role:</FormField.Label>
|
||||
<FormField.Control>
|
||||
<InputSelect
|
||||
value={field.value}
|
||||
onValueChange={(value) => helper.setValue(value)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a role" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item
|
||||
value={UserRole.LIMITED.toString()}
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.LIMITED]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value={UserRole.BASIC.toString()}>
|
||||
{USER_ROLE_LABELS[UserRole.BASIC]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value={UserRole.ADMIN.toString()}>
|
||||
{USER_ROLE_LABELS[UserRole.ADMIN]}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</FormField.Control>
|
||||
<FormField.Description>
|
||||
Select the role for this API key. Limited has access to
|
||||
simple public APIs. Basic has access to regular user
|
||||
APIs. Admin has access to admin level APIs.
|
||||
</FormField.Description>
|
||||
</FormField>
|
||||
)}
|
||||
/>
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button type="submit">
|
||||
{isUpdate ? "Update" : "Create"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
</Modal.Footer>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
import { APIKeyArgs, APIKey } from "./types";
|
||||
|
||||
export const createApiKey = async (apiKeyArgs: APIKeyArgs) => {
|
||||
return fetch("/api/admin/api-key", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(apiKeyArgs),
|
||||
});
|
||||
};
|
||||
|
||||
export const regenerateApiKey = async (apiKey: APIKey) => {
|
||||
return fetch(`/api/admin/api-key/${apiKey.api_key_id}/regenerate`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const updateApiKey = async (
|
||||
apiKeyId: number,
|
||||
apiKeyArgs: APIKeyArgs
|
||||
) => {
|
||||
return fetch(`/api/admin/api-key/${apiKeyId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(apiKeyArgs),
|
||||
});
|
||||
};
|
||||
|
||||
export const deleteApiKey = async (apiKeyId: number) => {
|
||||
return fetch(`/api/admin/api-key/${apiKeyId}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
};
|
||||
@@ -1,259 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
Table,
|
||||
} from "@/components/ui/table";
|
||||
import Title from "@/components/ui/title";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useState } from "react";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { deleteApiKey, regenerateApiKey } from "@/app/admin/api-key/lib";
|
||||
import OnyxApiKeyForm from "@/app/admin/api-key/OnyxApiKeyForm";
|
||||
import {
|
||||
APIKey,
|
||||
DISCORD_SERVICE_API_KEY_NAME,
|
||||
} from "@/app/admin/api-key/types";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
|
||||
function Main() {
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading,
|
||||
error,
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
const [showCreateUpdateForm, setShowCreateUpdateForm] = useState(false);
|
||||
const [selectedApiKey, setSelectedApiKey] = useState<APIKey | undefined>();
|
||||
|
||||
const handleEdit = (apiKey: APIKey) => {
|
||||
setSelectedApiKey(apiKey);
|
||||
setShowCreateUpdateForm(true);
|
||||
};
|
||||
|
||||
if (isLoading) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (!apiKeys || error) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to fetch API Keys"
|
||||
errorMsg={error?.info?.detail || error.toString()}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out the discord service key from the displayed list
|
||||
const filteredApiKeys = apiKeys.filter(
|
||||
(key) => key.api_key_name !== DISCORD_SERVICE_API_KEY_NAME
|
||||
);
|
||||
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
? " Click the button below to generate a new API Key."
|
||||
: ""}
|
||||
</Text>
|
||||
{canCreateKeys ? (
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : isTrialing ? (
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
|
||||
if (filteredApiKeys.length === 0) {
|
||||
return (
|
||||
<div>
|
||||
{introSection}
|
||||
|
||||
{showCreateUpdateForm && (
|
||||
<OnyxApiKeyForm
|
||||
onCreateApiKey={(apiKey) => {
|
||||
setFullApiKey(apiKey.api_key);
|
||||
}}
|
||||
onClose={() => {
|
||||
setShowCreateUpdateForm(false);
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal open={!!fullApiKey}>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
title="New API Key"
|
||||
icon={SvgKey}
|
||||
onClose={() => setFullApiKey(null)}
|
||||
description="Make sure you copy your new API key. You won't be able to see this key again."
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Text as="p" className="break-all flex-1">
|
||||
{fullApiKey}
|
||||
</Text>
|
||||
<CopyIconButton getCopyText={() => fullApiKey!} />
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
|
||||
{keyIsGenerating && <Spinner />}
|
||||
|
||||
{introSection}
|
||||
|
||||
{canCreateKeys && (
|
||||
<>
|
||||
<Separator />
|
||||
|
||||
<Title className="mt-6">Existing API Keys</Title>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Role</TableHead>
|
||||
<TableHead>Regenerate</TableHead>
|
||||
<TableHead>Delete</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredApiKeys.map((apiKey) => (
|
||||
<TableRow key={apiKey.api_key_id}>
|
||||
<TableCell>
|
||||
<Button
|
||||
prominence="internal"
|
||||
onClick={() => handleEdit(apiKey)}
|
||||
icon={SvgEdit}
|
||||
>
|
||||
{apiKey.api_key_name || "null"}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_display}
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_role.toUpperCase()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
prominence="internal"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={async () => {
|
||||
setKeyIsGenerating(true);
|
||||
const response = await regenerateApiKey(apiKey);
|
||||
setKeyIsGenerating(false);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(
|
||||
`Failed to regenerate API Key: ${errorMsg}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const newKey = (await response.json()) as APIKey;
|
||||
setFullApiKey(newKey.api_key);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
>
|
||||
Refresh
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DeleteButton
|
||||
onClick={async () => {
|
||||
const response = await deleteApiKey(apiKey.api_key_id);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to delete API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
{showCreateUpdateForm && (
|
||||
<OnyxApiKeyForm
|
||||
onCreateApiKey={(apiKey) => {
|
||||
setFullApiKey(apiKey.api_key);
|
||||
}}
|
||||
onClose={() => {
|
||||
setShowCreateUpdateForm(false);
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
1
web/src/app/admin/service-accounts/page.tsx
Normal file
1
web/src/app/admin/service-accounts/page.tsx
Normal file
@@ -0,0 +1 @@
|
||||
export { default } from "@/refresh-pages/admin/ServiceAccountsPage";
|
||||
@@ -182,7 +182,8 @@ export async function* sendMessage({
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
const data = await response.json().catch(() => ({}));
|
||||
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useCallback } from "react";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { AccountType, UserStatus } from "@/lib/types";
|
||||
import { UserStatus } from "@/lib/types";
|
||||
import type { UserRole, InvitedUserSnapshot } from "@/lib/types";
|
||||
import type {
|
||||
UserRow,
|
||||
@@ -19,7 +19,6 @@ interface FullUserSnapshot {
|
||||
id: string;
|
||||
email: string;
|
||||
role: UserRole;
|
||||
account_type: AccountType;
|
||||
is_active: boolean;
|
||||
password_configured: boolean;
|
||||
personal_name: string | null;
|
||||
|
||||
@@ -901,6 +901,11 @@ export default function useChatController({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
throw new Error(stack.error);
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.log("Error:", e);
|
||||
const errorMsg = e.message;
|
||||
|
||||
@@ -181,7 +181,7 @@ export const ADMIN_ROUTES = {
|
||||
sidebarLabel: "Users",
|
||||
},
|
||||
API_KEYS: {
|
||||
path: "/admin/api-key",
|
||||
path: "/admin/service-accounts",
|
||||
icon: SvgUserKey,
|
||||
title: "Service Accounts",
|
||||
sidebarLabel: "Service Accounts",
|
||||
|
||||
@@ -52,14 +52,6 @@ export interface UserPersonalization {
|
||||
user_preferences: string;
|
||||
}
|
||||
|
||||
export enum AccountType {
|
||||
STANDARD = "STANDARD",
|
||||
BOT = "BOT",
|
||||
EXT_PERM_USER = "EXT_PERM_USER",
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT",
|
||||
ANONYMOUS = "ANONYMOUS",
|
||||
}
|
||||
|
||||
export enum UserRole {
|
||||
LIMITED = "limited",
|
||||
BASIC = "basic",
|
||||
@@ -487,7 +479,6 @@ export interface UserGroup {
|
||||
personas: Persona[];
|
||||
is_up_to_date: boolean;
|
||||
is_up_for_deletion: boolean;
|
||||
is_default: boolean;
|
||||
}
|
||||
|
||||
export enum ValidSources {
|
||||
|
||||
@@ -87,7 +87,7 @@ function CreateGroupPage() {
|
||||
const headerActions = (
|
||||
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
prominence="secondary"
|
||||
onClick={() => router.push("/admin/groups")}
|
||||
>
|
||||
Cancel
|
||||
@@ -102,7 +102,7 @@ function CreateGroupPage() {
|
||||
);
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root width="sm">
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgUsers}
|
||||
title="Create Group"
|
||||
|
||||
@@ -287,7 +287,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
// 404 state
|
||||
if (!isLoading && !error && !group) {
|
||||
return (
|
||||
<SettingsLayouts.Root width="sm">
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgUsers}
|
||||
title="Group Not Found"
|
||||
@@ -307,7 +307,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
const headerActions = (
|
||||
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
prominence="secondary"
|
||||
onClick={() => router.push("/admin/groups")}
|
||||
>
|
||||
Cancel
|
||||
@@ -328,7 +328,7 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root width="sm">
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgUsers}
|
||||
title="Edit Group"
|
||||
|
||||
@@ -4,16 +4,14 @@ import type { Route } from "next";
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR from "swr";
|
||||
import { SvgPlusCircle, SvgUsers } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgUsers } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import { USER_GROUP_URL } from "./svc";
|
||||
import GroupsList from "./GroupsList";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import AdminListHeader from "@/sections/admin/AdminListHeader";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import SvgNoResult from "@opal/illustrations/no-result";
|
||||
|
||||
@@ -28,34 +26,22 @@ function GroupsPage() {
|
||||
} = useSWR<UserGroup[]>(USER_GROUP_URL, errorHandlingFetcher);
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root width="sm">
|
||||
{/* This is the sticky header for the groups page. It is used to display
|
||||
* the groups page title and search input when scrolling down.
|
||||
*/}
|
||||
<div
|
||||
className="sticky top-0 z-settings-header bg-background-tint-01"
|
||||
data-testid="groups-page-heading"
|
||||
>
|
||||
<SettingsLayouts.Root>
|
||||
<div data-testid="groups-page-heading">
|
||||
<SettingsLayouts.Header icon={SvgUsers} title="Groups" separator />
|
||||
|
||||
<Section flexDirection="row" padding={1}>
|
||||
<InputTypeIn
|
||||
placeholder="Search groups..."
|
||||
variant="internal"
|
||||
value={searchQuery}
|
||||
leftSearchIcon
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
/>
|
||||
<Button
|
||||
icon={SvgPlusCircle}
|
||||
onClick={() => router.push("/admin/groups/create" as Route)}
|
||||
>
|
||||
New Group
|
||||
</Button>
|
||||
</Section>
|
||||
</div>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<AdminListHeader
|
||||
hasItems={!isLoading && !error && (groups?.length ?? 0) > 0}
|
||||
searchQuery={searchQuery}
|
||||
onSearchQueryChange={setSearchQuery}
|
||||
placeholder="Search groups..."
|
||||
emptyStateText="Create groups to organize users and manage access."
|
||||
onAction={() => router.push("/admin/groups/create" as Route)}
|
||||
actionLabel="New Group"
|
||||
/>
|
||||
|
||||
{isLoading && <SimpleLoader />}
|
||||
|
||||
{error && (
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
|
||||
/** Whether this group is a system default group (Admin, Basic). */
|
||||
/** Groups that are created by the system and cannot be deleted. */
|
||||
export const BUILT_IN_GROUP_NAMES = ["Basic", "Admin"] as const;
|
||||
|
||||
export function isBuiltInGroup(group: UserGroup): boolean {
|
||||
return group.is_default;
|
||||
return (BUILT_IN_GROUP_NAMES as readonly string[]).includes(group.name);
|
||||
}
|
||||
|
||||
/** Human-readable description for built-in groups. */
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"use client";
|
||||
|
||||
import { Form, Formik } from "formik";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
createApiKey,
|
||||
updateApiKey,
|
||||
} from "@/refresh-pages/admin/ServiceAccountsPage/svc";
|
||||
import type { APIKey } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { FormikField } from "@/refresh-components/form/FormikField";
|
||||
import { Vertical as VerticalInput } from "@/layouts/input-layouts";
|
||||
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
|
||||
import { SvgKey, SvgLock, SvgUser, SvgUserManage } from "@opal/icons";
|
||||
|
||||
interface ApiKeyFormModalProps {
|
||||
onClose: () => void;
|
||||
onCreateApiKey: (apiKey: APIKey) => void;
|
||||
apiKey?: APIKey;
|
||||
}
|
||||
|
||||
export default function ApiKeyFormModal({
|
||||
onClose,
|
||||
onCreateApiKey,
|
||||
apiKey,
|
||||
}: ApiKeyFormModalProps) {
|
||||
const isUpdate = apiKey !== undefined;
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title={isUpdate ? "Update Service Account" : "Create Service Account"}
|
||||
description={
|
||||
isUpdate
|
||||
? undefined
|
||||
: "Use service account API key to programmatically access Onyx API with user-level permissions. You can modify the account details later."
|
||||
}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={{
|
||||
name: apiKey?.api_key_name || "",
|
||||
role: apiKey?.api_key_role || UserRole.BASIC.toString(),
|
||||
}}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
const payload = {
|
||||
...values,
|
||||
role: values.role as UserRole,
|
||||
};
|
||||
|
||||
try {
|
||||
let response;
|
||||
if (isUpdate) {
|
||||
response = await updateApiKey(apiKey.api_key_id, payload);
|
||||
} else {
|
||||
response = await createApiKey(payload);
|
||||
}
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
isUpdate
|
||||
? "Successfully updated service account!"
|
||||
: "Successfully created service account!"
|
||||
);
|
||||
if (!isUpdate) {
|
||||
onCreateApiKey(await response.json());
|
||||
}
|
||||
onClose();
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
toast.error(
|
||||
isUpdate
|
||||
? `Error updating service account - ${errorMsg}`
|
||||
: `Error creating service account - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
toast.error(
|
||||
e instanceof Error ? e.message : "An unexpected error occurred."
|
||||
);
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values }) => (
|
||||
<Form className="w-full overflow-visible">
|
||||
<Modal.Body>
|
||||
<VerticalInput
|
||||
name="name"
|
||||
title="Name"
|
||||
nonInteractive
|
||||
sizePreset="main-ui"
|
||||
>
|
||||
<FormikField<string>
|
||||
name="name"
|
||||
render={(field, helper) => (
|
||||
<InputTypeIn
|
||||
{...field}
|
||||
placeholder="Enter a name"
|
||||
onClear={() => helper.setValue("")}
|
||||
showClearButton={false}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</VerticalInput>
|
||||
|
||||
<VerticalInput
|
||||
name="role"
|
||||
title="Account Permissions"
|
||||
nonInteractive
|
||||
sizePreset="main-ui"
|
||||
>
|
||||
<FormikField<string>
|
||||
name="role"
|
||||
render={(field, helper) => (
|
||||
<InputSelect
|
||||
value={field.value}
|
||||
onValueChange={(value) => helper.setValue(value)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select permissions" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item
|
||||
value={UserRole.ADMIN.toString()}
|
||||
icon={SvgUserManage}
|
||||
description="Unrestricted admin access to all endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.ADMIN]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item
|
||||
value={UserRole.BASIC.toString()}
|
||||
icon={SvgUser}
|
||||
description="Standard user-level access to non-admin endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.BASIC]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item
|
||||
value={UserRole.LIMITED.toString()}
|
||||
icon={SvgLock}
|
||||
description="For agents: chat posting and read-only access to other endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.LIMITED]}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
)}
|
||||
/>
|
||||
</VerticalInput>
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<Button prominence="secondary" type="button" onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Disabled disabled={isSubmitting || !values.name.trim()}>
|
||||
<Button type="submit">
|
||||
{isUpdate ? "Update" : "Create Account"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
</Modal.Footer>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
461
web/src/refresh-pages/admin/ServiceAccountsPage/index.tsx
Normal file
461
web/src/refresh-pages/admin/ServiceAccountsPage/index.tsx
Normal file
@@ -0,0 +1,461 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { Content, IllustrationContent } from "@opal/layouts";
|
||||
import SvgNoResult from "@opal/illustrations/no-result";
|
||||
import {
|
||||
SvgDownload,
|
||||
SvgKey,
|
||||
SvgLock,
|
||||
SvgMoreHorizontal,
|
||||
SvgRefreshCw,
|
||||
SvgTrash,
|
||||
SvgUser,
|
||||
SvgUserEdit,
|
||||
SvgUserKey,
|
||||
SvgUserManage,
|
||||
} from "@opal/icons";
|
||||
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import AdminListHeader from "@/sections/admin/AdminListHeader";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Code from "@/refresh-components/Code";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import {
|
||||
deleteApiKey,
|
||||
regenerateApiKey,
|
||||
updateApiKey,
|
||||
} from "@/refresh-pages/admin/ServiceAccountsPage/svc";
|
||||
import type { APIKey } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
|
||||
import { DISCORD_SERVICE_API_KEY_NAME } from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
|
||||
import ApiKeyFormModal from "@/refresh-pages/admin/ServiceAccountsPage/ApiKeyFormModal";
|
||||
import { Table } from "@opal/components";
|
||||
import { createTableColumns } from "@opal/components/table/columns";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const API_KEY_SWR_KEY = "/api/admin/api-key";
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
|
||||
const tc = createTableColumns<APIKey>();
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function ServiceAccountsPage() {
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading,
|
||||
error,
|
||||
} = useSWR<APIKey[]>(API_KEY_SWR_KEY, errorHandlingFetcher);
|
||||
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [showCreateUpdateForm, setShowCreateUpdateForm] = useState(false);
|
||||
const [selectedApiKey, setSelectedApiKey] = useState<APIKey | undefined>();
|
||||
const [search, setSearch] = useState("");
|
||||
const [regenerateTarget, setRegenerateTarget] = useState<APIKey | null>(null);
|
||||
const [deleteTarget, setDeleteTarget] = useState<APIKey | null>(null);
|
||||
|
||||
const visibleApiKeys = (apiKeys ?? []).filter(
|
||||
(key) => key.api_key_name !== DISCORD_SERVICE_API_KEY_NAME
|
||||
);
|
||||
|
||||
const filteredApiKeys = visibleApiKeys.filter(
|
||||
(key) =>
|
||||
!search ||
|
||||
(key.api_key_name ?? "").toLowerCase().includes(search.toLowerCase()) ||
|
||||
key.api_key_display.toLowerCase().includes(search.toLowerCase())
|
||||
);
|
||||
|
||||
const handleRoleChange = async (apiKey: APIKey, newRole: UserRole) => {
|
||||
try {
|
||||
const response = await updateApiKey(apiKey.api_key_id, {
|
||||
name: apiKey.api_key_name ?? undefined,
|
||||
role: newRole,
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to update role: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate(API_KEY_SWR_KEY);
|
||||
toast.success("Role updated.");
|
||||
} catch {
|
||||
toast.error("Failed to update role.");
|
||||
}
|
||||
};
|
||||
|
||||
const handleRegenerate = async (apiKey: APIKey) => {
|
||||
try {
|
||||
const response = await regenerateApiKey(apiKey);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to regenerate API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
const newKey = (await response.json()) as APIKey;
|
||||
setFullApiKey(newKey.api_key);
|
||||
mutate(API_KEY_SWR_KEY);
|
||||
} catch (e) {
|
||||
toast.error(
|
||||
e instanceof Error ? e.message : "Failed to regenerate API Key."
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (apiKey: APIKey) => {
|
||||
try {
|
||||
const response = await deleteApiKey(apiKey.api_key_id);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to delete API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate(API_KEY_SWR_KEY);
|
||||
} catch (e) {
|
||||
toast.error(e instanceof Error ? e.message : "Failed to delete API Key.");
|
||||
}
|
||||
};
|
||||
|
||||
const columns = useMemo(
|
||||
() => [
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getContent: () => SvgUserKey,
|
||||
}),
|
||||
tc.column("api_key_name", {
|
||||
header: "Name",
|
||||
weight: 25,
|
||||
cell: (value) => (
|
||||
<Content
|
||||
title={value || "Unnamed"}
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
/>
|
||||
),
|
||||
}),
|
||||
tc.column("api_key_display", {
|
||||
header: "API Key",
|
||||
weight: 30,
|
||||
cell: (value) => (
|
||||
<Text font="secondary-mono" color="text-03">
|
||||
{value}
|
||||
</Text>
|
||||
),
|
||||
}),
|
||||
tc.displayColumn({
|
||||
id: "account_type",
|
||||
header: "Account Type",
|
||||
width: { weight: 25, minWidth: 160 },
|
||||
cell: (row) => (
|
||||
<InputSelect
|
||||
value={row.api_key_role}
|
||||
onValueChange={(value) => handleRoleChange(row, value as UserRole)}
|
||||
>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item
|
||||
value={UserRole.ADMIN.toString()}
|
||||
icon={SvgUserManage}
|
||||
description="Unrestricted admin access to all endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.ADMIN]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item
|
||||
value={UserRole.BASIC.toString()}
|
||||
icon={SvgUser}
|
||||
description="Standard user-level access to non-admin endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.BASIC]}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item
|
||||
value={UserRole.LIMITED.toString()}
|
||||
icon={SvgLock}
|
||||
description="For agents: chat posting and read-only access to other endpoints."
|
||||
>
|
||||
{USER_ROLE_LABELS[UserRole.LIMITED]}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
),
|
||||
}),
|
||||
tc.actions({
|
||||
cell: (row) => (
|
||||
<div className="flex flex-row gap-1">
|
||||
<Button
|
||||
icon={SvgRefreshCw}
|
||||
prominence="tertiary"
|
||||
tooltip="Regenerate"
|
||||
onClick={() => setRegenerateTarget(row)}
|
||||
/>
|
||||
<Popover>
|
||||
<Popover.Trigger asChild>
|
||||
<Button
|
||||
icon={SvgMoreHorizontal}
|
||||
prominence="tertiary"
|
||||
tooltip="More"
|
||||
/>
|
||||
</Popover.Trigger>
|
||||
<Popover.Content side="bottom" align="end" width="md">
|
||||
<PopoverMenu>
|
||||
<LineItem
|
||||
icon={SvgUserEdit}
|
||||
onClick={() => {
|
||||
setSelectedApiKey(row);
|
||||
setShowCreateUpdateForm(true);
|
||||
}}
|
||||
>
|
||||
Edit Account
|
||||
</LineItem>
|
||||
<LineItem
|
||||
icon={SvgTrash}
|
||||
danger
|
||||
onClick={() => setDeleteTarget(row)}
|
||||
>
|
||||
Delete Account
|
||||
</LineItem>
|
||||
</PopoverMenu>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
</div>
|
||||
),
|
||||
}),
|
||||
],
|
||||
[] // eslint-disable-line react-hooks/exhaustive-deps
|
||||
);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
title={route.title}
|
||||
icon={route.icon}
|
||||
description="Use service accounts to programmatically access Onyx API."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="Failed to load service accounts."
|
||||
description="Please check the console for more details."
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
title={route.title}
|
||||
icon={route.icon}
|
||||
description="Use service accounts to programmatically access Onyx API."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<SimpleLoader />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
const hasKeys = visibleApiKeys.length > 0;
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
title={route.title}
|
||||
icon={route.icon}
|
||||
description="Use service accounts to programmatically access Onyx API."
|
||||
separator
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col">
|
||||
<AdminListHeader
|
||||
hasItems={hasKeys}
|
||||
searchQuery={search}
|
||||
onSearchQueryChange={setSearch}
|
||||
placeholder="Search service accounts..."
|
||||
emptyStateText="Create service account API keys with user-level access."
|
||||
onAction={() => {
|
||||
setSelectedApiKey(undefined);
|
||||
setShowCreateUpdateForm(true);
|
||||
}}
|
||||
actionLabel="New Service Account"
|
||||
/>
|
||||
|
||||
{hasKeys && (
|
||||
<Table
|
||||
data={filteredApiKeys}
|
||||
getRowId={(row) => String(row.api_key_id)}
|
||||
columns={columns}
|
||||
searchTerm={search}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
<Modal open={!!fullApiKey}>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
title="Service Account API Key"
|
||||
icon={SvgKey}
|
||||
onClose={() => setFullApiKey(null)}
|
||||
description="Save this key before continuing. It won't be shown again."
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Code showCopyButton={false}>{fullApiKey ?? ""}</Code>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
left={
|
||||
<Button
|
||||
prominence="secondary"
|
||||
icon={SvgDownload}
|
||||
onClick={() => {
|
||||
if (!fullApiKey) return;
|
||||
const blob = new Blob([fullApiKey], {
|
||||
type: "text/plain",
|
||||
});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = "onyx-api-key.txt";
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}}
|
||||
>
|
||||
Download
|
||||
</Button>
|
||||
}
|
||||
submit={
|
||||
// TODO(@raunakab): Create an opalified copy-button and replace it here
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (fullApiKey) {
|
||||
navigator.clipboard.writeText(fullApiKey);
|
||||
toast.success("API key copied to clipboard.");
|
||||
}
|
||||
}}
|
||||
>
|
||||
Copy API Key
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
|
||||
{showCreateUpdateForm && (
|
||||
<ApiKeyFormModal
|
||||
onCreateApiKey={(apiKey) => {
|
||||
setFullApiKey(apiKey.api_key);
|
||||
}}
|
||||
onClose={() => {
|
||||
setShowCreateUpdateForm(false);
|
||||
setSelectedApiKey(undefined);
|
||||
mutate(API_KEY_SWR_KEY);
|
||||
}}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
|
||||
{regenerateTarget && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgRefreshCw}
|
||||
title="Regenerate API Key"
|
||||
onClose={() => setRegenerateTarget(null)}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={async () => {
|
||||
const target = regenerateTarget;
|
||||
setRegenerateTarget(null);
|
||||
await handleRegenerate(target);
|
||||
}}
|
||||
>
|
||||
Regenerate Key
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`Your current API key *${
|
||||
regenerateTarget.api_key_name || "Unnamed"
|
||||
}* (\`${
|
||||
regenerateTarget.api_key_display
|
||||
}\`) will be revoked and a new key will be generated. You will need to update any applications using this key with the new one.`
|
||||
)}
|
||||
</Text>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{deleteTarget && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgTrash}
|
||||
title="Delete Account"
|
||||
onClose={() => setDeleteTarget(null)}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={async () => {
|
||||
await handleDelete(deleteTarget);
|
||||
setDeleteTarget(null);
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Section alignItems="start" gap={0.5}>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`Any application using the API key of account *${
|
||||
deleteTarget.api_key_name || "Unnamed"
|
||||
}* (\`${
|
||||
deleteTarget.api_key_display
|
||||
}\`) will lose access to Onyx.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Deletion cannot be undone.
|
||||
</Text>
|
||||
</Section>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
import { UserRole } from "@/lib/types";
|
||||
|
||||
// Discord bot service API key name - should match backend constant
|
||||
export const DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service";
|
||||
|
||||
export interface APIKey {
|
||||
38
web/src/refresh-pages/admin/ServiceAccountsPage/svc.ts
Normal file
38
web/src/refresh-pages/admin/ServiceAccountsPage/svc.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import type {
|
||||
APIKeyArgs,
|
||||
APIKey,
|
||||
} from "@/refresh-pages/admin/ServiceAccountsPage/interfaces";
|
||||
|
||||
const API_KEY_URL = "/api/admin/api-key";
|
||||
|
||||
export async function createApiKey(args: APIKeyArgs): Promise<Response> {
|
||||
return fetch(API_KEY_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(args),
|
||||
});
|
||||
}
|
||||
|
||||
export async function regenerateApiKey(apiKey: APIKey): Promise<Response> {
|
||||
return fetch(`${API_KEY_URL}/${apiKey.api_key_id}/regenerate`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
export async function updateApiKey(
|
||||
apiKeyId: number,
|
||||
args: APIKeyArgs
|
||||
): Promise<Response> {
|
||||
return fetch(`${API_KEY_URL}/${apiKeyId}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(args),
|
||||
});
|
||||
}
|
||||
|
||||
export async function deleteApiKey(apiKeyId: number): Promise<Response> {
|
||||
return fetch(`${API_KEY_URL}/${apiKeyId}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
IconProps,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import ProviderCard from "@/sections/cards/ProviderCard";
|
||||
import ProviderCard from "@/sections/admin/ProviderCard";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@opal/components";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgPlusCircle } from "@opal/icons";
|
||||
interface ActionbarProps {
|
||||
hasActions: boolean;
|
||||
searchQuery?: string;
|
||||
onSearchQueryChange?: (query: string) => void;
|
||||
onAddAction: () => void;
|
||||
className?: string;
|
||||
buttonText?: string;
|
||||
barText?: string;
|
||||
}
|
||||
|
||||
const Actionbar: React.FC<ActionbarProps> = ({
|
||||
hasActions,
|
||||
searchQuery = "",
|
||||
onSearchQueryChange,
|
||||
onAddAction,
|
||||
className,
|
||||
buttonText = "Add MCP Server",
|
||||
barText = "Connect MCP server to add custom actions.",
|
||||
}) => {
|
||||
const handleSearchChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
onSearchQueryChange?.(e.target.value);
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex gap-4 items-center rounded-16",
|
||||
!hasActions ? "bg-background-tint-00 border border-border-01 p-4" : "",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{hasActions ? (
|
||||
<div className="flex-1 min-w-[160px]">
|
||||
<InputTypeIn
|
||||
placeholder="Search servers…"
|
||||
value={searchQuery}
|
||||
onChange={handleSearchChange}
|
||||
leftSearchIcon
|
||||
showClearButton
|
||||
className="w-full !bg-transparent !border-transparent [&:is(:hover,:active,:focus,:focus-within)]:!bg-background-neutral-00 [&:is(:hover,:active,:focus,:focus-within)]:!border-border-01 [&:is(:focus,:focus-within)]:!shadow-none"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex-1">
|
||||
<Text as="p" mainUiMuted text03>
|
||||
{barText}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2 items-center justify-end">
|
||||
<Button icon={SvgPlusCircle} onClick={onAddAction}>
|
||||
{buttonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
Actionbar.displayName = "Actionbar";
|
||||
export default Actionbar;
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useState, useCallback, useMemo, useEffect } from "react";
|
||||
import { KeyedMutator } from "swr";
|
||||
import MCPActionCard from "@/sections/actions/MCPActionCard";
|
||||
import Actionbar from "@/sections/actions/Actionbar";
|
||||
import AdminListHeader from "@/sections/admin/AdminListHeader";
|
||||
import ActionCardSkeleton from "@/sections/actions/skeleton/ActionCardSkeleton";
|
||||
import { getActionIcon } from "@/lib/tools/mcpUtils";
|
||||
import {
|
||||
@@ -487,13 +487,13 @@ export default function MCPPageContent() {
|
||||
)}
|
||||
|
||||
<div className="flex-shrink-0 mb-4">
|
||||
<Actionbar
|
||||
hasActions={isLoading || mcpServers.length > 0}
|
||||
<AdminListHeader
|
||||
hasItems={isLoading || mcpServers.length > 0}
|
||||
searchQuery={searchQuery}
|
||||
onSearchQueryChange={setSearchQuery}
|
||||
onAddAction={handleAddServer}
|
||||
buttonText="Add MCP Server"
|
||||
barText="Connect MCP server to add custom actions."
|
||||
onAction={handleAddServer}
|
||||
actionLabel="Add MCP Server"
|
||||
emptyStateText="Connect MCP server to add custom actions."
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import OpenAPIAuthenticationModal, {
|
||||
OpenAPIAuthFormValues,
|
||||
} from "./modals/OpenAPIAuthenticationModal";
|
||||
import AddOpenAPIActionModal from "./modals/AddOpenAPIActionModal";
|
||||
import Actionbar from "./Actionbar";
|
||||
import AdminListHeader from "@/sections/admin/AdminListHeader";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import OpenApiActionCard from "./OpenApiActionCard";
|
||||
import { createOAuthConfig, updateOAuthConfig } from "@/lib/oauth/api";
|
||||
@@ -350,13 +350,13 @@ export default function OpenApiPageContent() {
|
||||
)}
|
||||
|
||||
<div className="flex-shrink-0 mb-4">
|
||||
<Actionbar
|
||||
hasActions={isOpenApiLoading || (openApiTools?.length ?? 0) > 0}
|
||||
<AdminListHeader
|
||||
hasItems={isOpenApiLoading || (openApiTools?.length ?? 0) > 0}
|
||||
searchQuery={searchQuery}
|
||||
onSearchQueryChange={setSearchQuery}
|
||||
onAddAction={handleAddAction}
|
||||
buttonText="Add OpenAPI Action"
|
||||
barText="Add custom actions from OpenAPI schemas."
|
||||
onAction={handleAddAction}
|
||||
actionLabel="Add OpenAPI Action"
|
||||
emptyStateText="Add custom actions from OpenAPI schemas."
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
98
web/src/sections/admin/AdminListHeader.tsx
Normal file
98
web/src/sections/admin/AdminListHeader.tsx
Normal file
@@ -0,0 +1,98 @@
|
||||
"use client";
|
||||
|
||||
import { Button, Card } from "@opal/components";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgPlusCircle } from "@opal/icons";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
|
||||
interface AdminListHeaderProps {
|
||||
/** Whether items exist — controls search bar vs empty-state card. */
|
||||
hasItems: boolean;
|
||||
/** Current search query. */
|
||||
searchQuery: string;
|
||||
/** Called when the search query changes. */
|
||||
onSearchQueryChange: (query: string) => void;
|
||||
/** Search input placeholder. */
|
||||
placeholder?: string;
|
||||
/** Text shown in the empty-state card when no items exist. */
|
||||
emptyStateText: string;
|
||||
/** Called when the action button is clicked. */
|
||||
onAction: () => void;
|
||||
/** Label for the action button. */
|
||||
actionLabel: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* AdminListHeader — the top bar for simple admin list pages.
|
||||
*
|
||||
* Handles two states:
|
||||
*
|
||||
* 1. **Items exist** (`hasItems = true`): renders a search input on the left
|
||||
* with a primary action button on the right.
|
||||
* 2. **No items** (`hasItems = false`): renders a bordered card with
|
||||
* descriptive text on the left and the same action button on the right.
|
||||
*
|
||||
* The action button always renders with a `SvgPlusCircle` right icon.
|
||||
*
|
||||
* Used on admin pages that have a flat list of items with no advanced
|
||||
* filtering — e.g. Service Accounts, Groups, OpenAPI Actions, MCP Servers.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <AdminListHeader
|
||||
* hasItems={items.length > 0}
|
||||
* searchQuery={search}
|
||||
* onSearchQueryChange={setSearch}
|
||||
* placeholder="Search service accounts..."
|
||||
* emptyStateText="Create service account API keys with user-level access."
|
||||
* onAction={handleCreate}
|
||||
* actionLabel="New Service Account"
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export default function AdminListHeader({
|
||||
hasItems,
|
||||
searchQuery,
|
||||
onSearchQueryChange,
|
||||
placeholder = "Search...",
|
||||
emptyStateText,
|
||||
onAction,
|
||||
actionLabel,
|
||||
}: AdminListHeaderProps) {
|
||||
const actionButton = (
|
||||
<Button rightIcon={SvgPlusCircle} onClick={onAction}>
|
||||
{actionLabel}
|
||||
</Button>
|
||||
);
|
||||
|
||||
if (!hasItems) {
|
||||
return (
|
||||
<Card paddingVariant="md" roundingVariant="lg" borderVariant="solid">
|
||||
<div className="flex flex-row items-center justify-between gap-3">
|
||||
<Content
|
||||
title={emptyStateText}
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
widthVariant="fit"
|
||||
/>
|
||||
{actionButton}
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-row gap-3 items-center px-2 pb-3">
|
||||
<InputTypeIn
|
||||
variant="internal"
|
||||
leftSearchIcon
|
||||
placeholder={placeholder}
|
||||
value={searchQuery}
|
||||
onChange={(e) => onSearchQueryChange(e.target.value)}
|
||||
showClearButton={false}
|
||||
/>
|
||||
{actionButton}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -11,6 +11,40 @@ import {
|
||||
SvgUnplug,
|
||||
} from "@opal/icons";
|
||||
|
||||
/**
|
||||
* ProviderCard — a stateful card for selecting / connecting / disconnecting
|
||||
* an external service provider (LLM, search engine, voice model, etc.).
|
||||
*
|
||||
* Built on opal `SelectCard` + `CardHeaderLayout`. Maps a three-state
|
||||
* status model to the `SelectCard` state system:
|
||||
*
|
||||
* | Status | SelectCard state | Right action |
|
||||
* |----------------|------------------|------------------------|
|
||||
* | `disconnected` | `empty` | "Connect" button |
|
||||
* | `connected` | `filled` | "Set as Default" button|
|
||||
* | `selected` | `selected` | "Current Default" label|
|
||||
*
|
||||
* Bottom-right actions (Disconnect, Edit) are always visible when the
|
||||
* provider is connected or selected.
|
||||
*
|
||||
* Used on admin configuration pages: Web Search, Image Generation,
|
||||
* Voice, and LLM Configuration.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <ProviderCard
|
||||
* icon={SvgGlobe}
|
||||
* title="Exa"
|
||||
* description="Exa.ai"
|
||||
* status="connected"
|
||||
* onConnect={() => openModal()}
|
||||
* onSelect={() => setDefault(id)}
|
||||
* onEdit={() => openEditModal()}
|
||||
* onDisconnect={() => confirmDisconnect(id)}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
|
||||
type ProviderStatus = "disconnected" | "connected" | "selected";
|
||||
|
||||
interface ProviderCardProps {
|
||||
@@ -53,18 +53,19 @@ test.describe("Groups page — layout", () => {
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
const groups = await api.getUserGroups();
|
||||
const adminGroup = groups.find((g) => g.name === "Admin" && g.is_default);
|
||||
const basicGroup = groups.find((g) => g.name === "Basic" && g.is_default);
|
||||
if (!adminGroup || !basicGroup) {
|
||||
throw new Error("Default Admin/Basic groups not found");
|
||||
}
|
||||
adminGroupId = adminGroup.id;
|
||||
basicGroupId = basicGroup.id;
|
||||
adminGroupId = await api.createUserGroup("Admin");
|
||||
basicGroupId = await api.createUserGroup("Basic");
|
||||
await api.waitForGroupSync(adminGroupId);
|
||||
await api.waitForGroupSync(basicGroupId);
|
||||
});
|
||||
});
|
||||
|
||||
// No afterAll — these are built-in default groups and must not be deleted
|
||||
test.afterAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
await softCleanup(() => api.deleteUserGroup(adminGroupId));
|
||||
await softCleanup(() => api.deleteUserGroup(basicGroupId));
|
||||
});
|
||||
});
|
||||
|
||||
test("renders page title, search, and new group button", async ({
|
||||
groupsPage,
|
||||
@@ -76,8 +77,7 @@ test.describe("Groups page — layout", () => {
|
||||
await expect(groupsPage.newGroupButton).toBeVisible();
|
||||
});
|
||||
|
||||
test.skip("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
|
||||
// TODO: Enable once default groups are shown via include_default=true
|
||||
test("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
|
||||
await groupsPage.goto();
|
||||
|
||||
await groupsPage.expectGroupVisible("Admin");
|
||||
|
||||
@@ -632,18 +632,6 @@ export class OnyxApiClient {
|
||||
this.log(`Deleted user group: ${groupId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists all user groups.
|
||||
*/
|
||||
async getUserGroups(): Promise<
|
||||
Array<{ id: number; name: string; is_default: boolean }>
|
||||
> {
|
||||
const response = await this.get(
|
||||
"/manage/admin/user-group?include_default=true"
|
||||
);
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async setUserRole(
|
||||
email: string,
|
||||
role: "admin" | "curator" | "global_curator" | "basic",
|
||||
|
||||
Reference in New Issue
Block a user