mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-03 22:12:43 +00:00
Compare commits
20 Commits
feat/resol
...
refactor/d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22903c9343 | ||
|
|
17b0d19faf | ||
|
|
3b6955468e | ||
|
|
9c091ecd45 | ||
|
|
67c8df002e | ||
|
|
722f7de335 | ||
|
|
df14bbe0e2 | ||
|
|
3db1ad82ce | ||
|
|
1e7882529c | ||
|
|
5d405cfa2d | ||
|
|
de3a253ea9 | ||
|
|
d6946a66a5 | ||
|
|
11835a0268 | ||
|
|
519fb61cc7 | ||
|
|
02671937fb | ||
|
|
1466158c1e | ||
|
|
073cf11c42 | ||
|
|
a2b0c15027 | ||
|
|
a462678ddd | ||
|
|
c50d2739b8 |
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -0,0 +1,104 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists: Any = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# The no-auth placeholder user must NOT be assigned to default groups.
|
||||
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
|
||||
# user when the first real user registers; group membership rows would cause
|
||||
# an FK violation on that DELETE.
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -36,13 +36,16 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -280,7 +283,9 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -521,6 +526,22 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -36,6 +39,7 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -255,6 +259,7 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -269,6 +274,7 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -276,6 +282,8 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -286,6 +294,7 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -295,6 +304,8 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -478,6 +489,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -489,6 +510,8 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -796,6 +819,10 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
recompute_user_permissions__no_commit(
|
||||
list(set(added_user_ids) | set(removed_user_ids)), db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -835,6 +862,19 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -863,6 +903,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -52,16 +52,25 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -486,6 +495,7 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -506,13 +516,25 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
@@ -542,7 +564,8 @@ def replace_user(
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
is_reactivation = user_resource.active and not user.is_active
|
||||
if is_reactivation:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
@@ -556,6 +579,12 @@ def replace_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
@@ -621,6 +650,7 @@ def patch_user(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
is_reactivation = patched.active and not user.is_active
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
@@ -649,6 +679,12 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
@@ -857,6 +893,11 @@ def create_group(
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if group_resource.displayName in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
@@ -879,8 +920,18 @@ def create_group(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
# Every group gets the "basic" permission by default.
|
||||
dal.add_permission_grant_to_group(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
# Recompute permissions for initial members.
|
||||
recompute_user_permissions__no_commit(member_uuids, db_session)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
@@ -911,14 +962,36 @@ def replace_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if (
|
||||
group_resource.displayName in _RESERVED_GROUP_NAMES
|
||||
and group_resource.displayName != group.name
|
||||
):
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
# Capture old member IDs before replacing so we can recompute their
|
||||
# permissions after they are removed from the group.
|
||||
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
# Recompute permissions for current members (batch) and removed members.
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
removed_ids = list(old_member_ids - set(member_uuids))
|
||||
recompute_user_permissions__no_commit(removed_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
@@ -961,8 +1034,19 @@ def patch_group(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and new_name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if new_name and new_name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
|
||||
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
affected_uuids: list[UUID] = []
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
@@ -973,10 +1057,15 @@ def patch_group(
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
affected_uuids.extend(add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
affected_uuids.extend(remove_uuids)
|
||||
|
||||
# Recompute permissions for all users whose group membership changed.
|
||||
recompute_user_permissions__no_commit(affected_uuids, db_session)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
@@ -1002,11 +1091,21 @@ def delete_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
|
||||
|
||||
# Capture member IDs before deletion so we can recompute their permissions.
|
||||
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
|
||||
# Recompute permissions for users who lost this group membership.
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -43,12 +43,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -56,27 +60,50 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -100,6 +127,9 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -185,6 +215,9 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,6 +22,7 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -74,18 +75,21 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
110
backend/onyx/auth/permissions.py
Normal file
110
backend/onyx/auth/permissions.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,6 +5,8 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -50,19 +53,19 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -80,7 +80,6 @@ from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -120,11 +119,13 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -500,18 +501,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -525,18 +529,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -573,6 +580,38 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
def _upgrade_user_to_standard__sync(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
All writes happen in a single sync transaction so neither the field
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
"User %s not found in sync session during upgrade to standard; "
|
||||
"skipping upgrade",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -694,6 +733,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -726,7 +766,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
# We must use the existing user in the session if it matches
|
||||
# the user we just got by email/oauth. Note that this only applies
|
||||
# to multi-tenant, due to the overwriting of the user_db
|
||||
@@ -743,14 +783,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user.id)
|
||||
assert user is not None
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -836,6 +887,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -975,7 +1036,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
@@ -1471,7 +1532,7 @@ async def _get_or_create_user_from_jwt(
|
||||
if not user.is_active:
|
||||
logger.warning("Inactive user %s attempted JWT login; skipping", email)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Provisioning user %s from JWT login", email)
|
||||
@@ -1492,7 +1553,7 @@ async def _get_or_create_user_from_jwt(
|
||||
email,
|
||||
)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
logger.warning(
|
||||
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
|
||||
email,
|
||||
@@ -1554,6 +1615,7 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -1,25 +1,33 @@
|
||||
# Overview of Context Management
|
||||
|
||||
This document reviews some design decisions around the main agent-loop powering Onyx's chat flow.
|
||||
It is highly recommended for all engineers contributing to this flow to be familiar with the concepts here.
|
||||
|
||||
> Note: it is assumed the reader is familiar with the Onyx product and features such as Projects, User files, Citations, etc.
|
||||
|
||||
## System Prompt
|
||||
|
||||
The system prompt is a default prompt that comes packaged with the system. Users can edit the default prompt and it will be persisted in the database.
|
||||
|
||||
Some parts of the system prompt are dynamically updated / inserted:
|
||||
|
||||
- Datetime of the message sent
|
||||
- Tools description of when to use certain tools depending on if the tool is available in that cycle
|
||||
- If the user has just called a search related tool, then a section about citations is included
|
||||
|
||||
|
||||
## Custom Agent Prompt
|
||||
|
||||
The custom agent is inserted as a user message above the most recent user message, it is dynamically moved in the history as the user sends more messages.
|
||||
If the user has opted to completely replace the System Prompt, then this Custom Agent prompt replaces the system prompt and does not move along the history.
|
||||
|
||||
|
||||
## How Files are handled
|
||||
|
||||
On upload, Files are processed for tokens, if too many tokens to fit in the context, it’s considered a failed inclusion. This is done using the LLM tokenizer.
|
||||
|
||||
- In many cases, there is not a known tokenizer for each LLM so there is a default tokenizer used as a catchall.
|
||||
- File upload happens in 2 parts - the actual upload + token counting.
|
||||
- Files are added into chat context as a “point in time” inclusion and move up the context window as the conversation progresses.
|
||||
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
|
||||
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
|
||||
|
||||
Image files are attached to User Messages also as point in time inclusions.
|
||||
|
||||
@@ -27,8 +35,8 @@ Image files are attached to User Messages also as point in time inclusions.
|
||||
Files selected from the search results are also counted as “point in time” inclusions. Files that are too large cannot be selected.
|
||||
For these files, the "entire file" does not exist for most connectors, it's pieced back together from the search engine.
|
||||
|
||||
|
||||
## Projects
|
||||
|
||||
If a Project contains few enough files that it all fits in the model context, we keep it close enough in the history to ensure it is easy for the LLM to
|
||||
access. Note that the project documents are assumed to be quite useful and that they should 1. never be dropped from context, 2. is not just a needle in
|
||||
a haystack type search with a strong keyword to make the LLM attend to it.
|
||||
@@ -36,11 +44,12 @@ a haystack type search with a strong keyword to make the LLM attend to it.
|
||||
Project files are vectorized and stored in the Search Engine so that if the user chooses a model with less context than the number of tokens in the project,
|
||||
the system can RAG over the project files.
|
||||
|
||||
|
||||
## How documents are represented
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix to make the
|
||||
context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of message
|
||||
rather than a user message.
|
||||
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix string to
|
||||
make the context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of
|
||||
message rather than a user message.
|
||||
|
||||
```
|
||||
Here are some documents provided for context, they may not all be relevant:
|
||||
{
|
||||
@@ -50,33 +59,37 @@ Here are some documents provided for context, they may not all be relevant:
|
||||
]
|
||||
}
|
||||
```
|
||||
Documents are represented with document so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
|
||||
Documents are represented with the `document` key so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
translate this into links and other UI elements. What the LLM sees is far simpler to reduce noise/hallucinations.
|
||||
|
||||
Note that documents included in a single turn should be collapsed into a single user message.
|
||||
|
||||
Search tools give URLs to the LLM though so that open_url (a separate tool) can be called on them.
|
||||
|
||||
Search tools also give URLs to the LLM so that open_url (a separate tool) can be called on them.
|
||||
|
||||
## Reminders
|
||||
|
||||
To ensure the LLM follows certain specific instructions, instructions are added at the very end of the chat context as a user message. If a search related
|
||||
tool is used, a citation reminder is always added. Otherwise, by default there is no reminder. If the user configures reminders, those are added to the
|
||||
final message. If a search related tool just ran and the user has reminders, both appear in a single message.
|
||||
|
||||
If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent has responded.
|
||||
|
||||
|
||||
## Tool Calls
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are today replaced with a hardcoded
|
||||
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are current replaced with a hardcoded
|
||||
string saying it is no longer available. Tool Call details like the search query and other arguments are kept in the history as this is information
|
||||
rich and generally very few tokens.
|
||||
|
||||
> Note: in the Internal Search flow with query expansion, the Tool Call which was actually run differs from what the LLM provided as arguments.
|
||||
> What the LLM sees in the history (to be most informative for future calls) is the full set of expanded queries.
|
||||
|
||||
**Possible Future Extension**:
|
||||
Instead of dropping the Tool Call response, we might summarize it using an LLM so that it is just 1-2 sentences and captures the main points. That said,
|
||||
this is questionable value add because anything relevant and useful should be already captured in the Agent response.
|
||||
|
||||
|
||||
## Examples
|
||||
|
||||
```
|
||||
S -> System Message
|
||||
CA -> Custom Agent as a User Message
|
||||
@@ -98,15 +111,15 @@ Flow with Project and File Upload
|
||||
S, CA, P, F, U1, A1 -- user sends another message -> S, F, U1, A1, CA, P, U2, A2
|
||||
- File stays in place, above the user message
|
||||
- Project files move along the chain as new messages are sent
|
||||
- Custom Agent prompt comes before project files which comes before user uploaded files in each turn
|
||||
- Custom Agent prompt comes before project files which come before user uploaded files in each turn
|
||||
|
||||
Reminders during a single Turn
|
||||
S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
|
||||
- Reminder moved to the end
|
||||
```
|
||||
|
||||
|
||||
## Product considerations
|
||||
|
||||
Project files are important to the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
|
||||
those files. The LLM is much better at referencing documents close to the end of the context window so keeping it there for ease of access.
|
||||
|
||||
@@ -117,9 +130,9 @@ User Message further away. This tradeoff is accepted for Projects because of the
|
||||
Reminder are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. It is less detailed than the system prompt
|
||||
and should be very targetted for it to work reliably and also not interfere with the last user message.
|
||||
|
||||
|
||||
## Reasons / Experiments
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrade performance of the system especially when the instructions
|
||||
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrades performance of the system especially when the instructions
|
||||
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses
|
||||
that completely ruins the user experience. Empirically, this way works better across a range of models especially when the history gets longer.
|
||||
Having the Custom Agent instructions not move means it fades more as the chat gets long which is also not ok from a UX perspective.
|
||||
@@ -146,10 +159,10 @@ In a similar concept, LLM instructions in the system prompt are structured speci
|
||||
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
|
||||
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
|
||||
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
|
||||
rate even just moving the same statement a few sentences.
|
||||
|
||||
rate by even just moving the same statement a few sentences.
|
||||
|
||||
## Other related pointers
|
||||
|
||||
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
|
||||
|
||||
---
|
||||
@@ -160,32 +173,38 @@ rate even just moving the same statement a few sentences.
|
||||
Turn: User sends a message and AI does some set of things and responds
|
||||
Step/Cycle: 1 single LLM inference given some context and some tools
|
||||
|
||||
|
||||
## 1. Top Level (process_message function):
|
||||
|
||||
This function can be thought of as the set-up and validation layer. It ensures that the database is in a valid state, reads the
|
||||
messages in the session and sets up all the necessary items to run the chat loop and state containers. The major things it does
|
||||
are:
|
||||
|
||||
- Validates the request
|
||||
- Builds the chat history for the session
|
||||
- Fetches any additional context such as files and images
|
||||
- Prepares all of the tools for the LLM
|
||||
- Creates the state container objects for use in the loop
|
||||
|
||||
### Wrapper (run_chat_loop_with_state_containers function):
|
||||
This wrapper is used to run the LLM flow in a background thread and monitor the emitter for stop signals. This means the top
|
||||
level is as isolated from the LLM flow as possible and can continue to yield packets as soon as they are available from the lower
|
||||
levels. This also means that if the lower levels fail, the top level will still guarantee a reasonable response to the user.
|
||||
All of the saving and database operations are abstracted away from the lower levels.
|
||||
### Execution (`_run_models` function):
|
||||
|
||||
Each model runs in its own worker thread inside a `ThreadPoolExecutor`. Workers write packets to a shared
|
||||
`merged_queue` via an `Emitter`; the main thread drains the queue and yields packets in arrival order. This
|
||||
means the top level is isolated from the LLM flow and can yield packets as soon as they are produced. If a
|
||||
worker fails, the main thread yields a `StreamingError` for that model and keeps the other models running.
|
||||
All saving and database operations are handled by the main thread after the workers complete (or by the
|
||||
workers themselves via self-completion if the drain loop exits early).
|
||||
|
||||
### Emitter
|
||||
The emitter is designed to be an object queue so that lower levels do not need to yield objects all the way back to the top.
|
||||
This way the functions can be better designed (not everything as a generator) and more easily tested. The wrapper around the
|
||||
LLM flow (run_chat_loop_with_state_containers) is used to monitor the emitter and handle packets as soon as they are available
|
||||
from the lower levels. Both the emitter and the state container are mutating state objects and only used to accumulate state.
|
||||
There should be no logic dependent on the states of these objects, especially in the lower levels. The emitter should only take
|
||||
packets and should not be used for other things.
|
||||
|
||||
The emitter is an object that lower levels use to send packets without needing to yield them all the way back
|
||||
up the call stack. Each `Emitter` tags every packet with a `model_index` and places it on the shared
|
||||
`merged_queue` as a `(model_idx, packet)` tuple. The drain loop in `_run_models` consumes these tuples and
|
||||
yields the packets to the caller. Both the emitter and the state container are mutating state objects used
|
||||
only to accumulate state. There should be no logic dependent on the states of these objects, especially in
|
||||
the lower levels. The emitter should only take packets and should not be used for other things.
|
||||
|
||||
### State Container
|
||||
|
||||
The state container is used to accumulate state during the LLM flow. Similar to the emitter, it should not be used for logic,
|
||||
only for accumulating state. It is used to gather all of the necessary information for saving the chat turn into the database.
|
||||
So it will accumulate answer tokens, reasoning tokens, tool calls, citation info, etc. This is used at the end of the flow once
|
||||
@@ -193,35 +212,40 @@ the lower level is completed whether on its own or stopped by the user. At that
|
||||
the database. The state container can be added to by any of the underlying layers, this is fine.
|
||||
|
||||
### Stopping Generation
|
||||
A stop signal is checked every 300ms by the wrapper around the LLM flow. The signal itself
|
||||
is stored in Redis and is set by the user calling the stop endpoint. The wrapper ensures that no matter what the lower level is
|
||||
doing at the time, the thread can be killed by the top level. It does not require a cooperative cancellation from the lower level
|
||||
and in fact the lower level does not know about the stop signal at all.
|
||||
|
||||
The drain loop in `_run_models` checks `check_is_connected()` every 50 ms (on queue timeout). The signal itself
|
||||
is stored in Redis and is set by the user calling the stop endpoint. On disconnect, the drain loop saves
|
||||
partial state for every model, yields an `OverallStop(stop_reason="user_cancelled")` packet, and returns.
|
||||
A `drain_done` event signals emitters to stop blocking so worker threads can exit quickly. Workers that
|
||||
already completed successfully will self-complete (persist their response) if the drain loop exited before
|
||||
reaching the normal completion path.
|
||||
|
||||
## 2. LLM Loop (run_llm_loop function)
|
||||
|
||||
This function handles the logic of the Turn. It's essentially a while loop where context is added and modified (according what
|
||||
is outlined in the first half of this doc). Its main functionality is:
|
||||
|
||||
- Translate and truncate the context for the LLM inference
|
||||
- Add context modifiers like reminders, updates to the system prompts, etc.
|
||||
- Run tool calls and gather results
|
||||
- Build some of the objects stored in the state container.
|
||||
|
||||
|
||||
## 3. LLM Step (run_llm_step function)
|
||||
|
||||
This function is a single inference of the LLM. It's a wrapper around the LLM stream function which handles packet translations
|
||||
so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they
|
||||
do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different
|
||||
tool calls and returns that to the LLM Loop to execute.
|
||||
|
||||
|
||||
## Things to know
|
||||
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
|
||||
- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
|
||||
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
|
||||
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
|
||||
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
|
||||
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend.
|
||||
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
|
||||
- There are 3 representations of a message, each scoped to a different layer:
|
||||
1. **ChatMessage** — The database model. Should be converted into ChatMessageSimple early and never passed deep into the flow.
|
||||
2. **ChatMessageSimple** — The canonical data model used throughout the codebase. This is the rich, full-featured representation
|
||||
of a message. Any modifications or additions to message structure should be made here.
|
||||
3. **LanguageModelInput** — The LLM-facing representation. Intentionally minimal so the LLM interface layer stays clean and
|
||||
easy to maintain/extend.
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -161,112 +170,45 @@ class ChatStateContainer:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
class AvailableFiles(BaseModel):
|
||||
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
# IDs from the ``user_file`` table (project / persona-attached files).
|
||||
user_file_ids: list[UUID] = []
|
||||
# IDs from the ``file_record`` table (chat-attached files).
|
||||
chat_file_ids: list[UUID] = []
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
@dataclass(frozen=True)
|
||||
class ChatTurnSetup:
|
||||
"""Immutable context produced by ``build_chat_turn`` and consumed by ``_run_models``."""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
new_msg_req: SendMessageRequest
|
||||
chat_session: ChatSession
|
||||
persona: Persona
|
||||
user_message: ChatMessage
|
||||
user_identity: LLMUserIdentity
|
||||
llms: list[LLM] # length 1 for single-model, N for multi-model
|
||||
model_display_names: list[str] # parallel to llms
|
||||
simple_chat_history: list[ChatMessageSimple]
|
||||
extracted_context_files: ExtractedContextFiles
|
||||
reserved_messages: list[ChatMessage] # length 1 for single, N for multi
|
||||
reserved_token_count: int
|
||||
search_params: SearchParams
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
available_files: AvailableFiles
|
||||
tool_id_to_name_map: dict[int, str]
|
||||
forced_tool_id: int | None
|
||||
files: list[ChatLoadedFile]
|
||||
chat_files_for_tools: list[ChatFile]
|
||||
custom_agent_prompt: str | None
|
||||
user_memory_context: UserMemoryContext
|
||||
# For deep research: was the last assistant message a clarification request?
|
||||
skip_clarification: bool
|
||||
check_is_connected: Callable[[], bool]
|
||||
cache: CacheBackend
|
||||
# Execution params forwarded to per-model tool construction
|
||||
bypass_acl: bool
|
||||
slack_context: SlackContext | None
|
||||
custom_tool_additional_headers: dict[str, str] | None
|
||||
mcp_headers: dict[str, str] | None
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -286,11 +286,9 @@ USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
# Profiling adds some overhead to OpenSearch operations. This overhead is
|
||||
# unknown right now. It is enabled by default so we can get useful logs for
|
||||
# investigating slow queries. We may never disable it if the overhead is
|
||||
# minimal.
|
||||
# unknown right now. Defaults to True.
|
||||
OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "true").lower() == "true"
|
||||
)
|
||||
# Whether to disable match highlights for OpenSearch. Defaults to True for now
|
||||
# as we investigate query performance.
|
||||
@@ -942,9 +940,20 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
# This is the timeout for the client side of the Vespa migration task. When
|
||||
# exceeded, an exception is raised in our code. This value should be higher than
|
||||
# VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT.
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
# This is the timeout Vespa uses on the server side to know when to wrap up its
|
||||
# traversal and try to report partial results. This differs from the client
|
||||
# timeout above which raises an exception in our code when exceeded. This
|
||||
# timeout allows Vespa to return gracefully. This value should be lower than
|
||||
# VESPA_MIGRATION_REQUEST_TIMEOUT_S. Formatted as <number of seconds>s.
|
||||
VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT = os.environ.get(
|
||||
"VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT", "110s"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
|
||||
@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
@@ -73,13 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -207,9 +202,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1672,82 +1665,6 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
]
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
|
||||
|
||||
for doc_id, error in batch_result.errors.items():
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=doc_id,
|
||||
),
|
||||
failure_message=f"Failed to retrieve file during error resolution: {error}",
|
||||
exception=error,
|
||||
)
|
||||
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
retrieved_files = [
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in batch_result.files.values()
|
||||
]
|
||||
|
||||
yield from self._get_new_ancestors_for_files(
|
||||
files=retrieved_files,
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(rf, permission_sync_context),
|
||||
)
|
||||
for rf in retrieved_files
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
for result in results:
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
|
||||
@@ -9,7 +9,6 @@ from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
|
||||
|
||||
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
|
||||
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().list() based on the field type enum."""
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _extract_single_file_fields(list_fields: str) -> str:
|
||||
"""Convert a files().list() fields string to one suitable for files().get().
|
||||
|
||||
List fields look like "nextPageToken, files(field1, field2, ...)"
|
||||
Single-file fields should be just "field1, field2, ..."
|
||||
"""
|
||||
start = list_fields.find("files(")
|
||||
if start == -1:
|
||||
return list_fields
|
||||
inner_start = start + len("files(")
|
||||
inner_end = list_fields.rfind(")")
|
||||
return list_fields[inner_start:inner_end]
|
||||
|
||||
|
||||
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().get() based on the field type enum."""
|
||||
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class BatchRetrievalResult:
|
||||
"""Result of a batch file retrieval, separating successes from errors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.files: dict[str, GoogleDriveFileType] = {}
|
||||
self.errors: dict[str, Exception] = {}
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
field_type: DriveFileFieldType,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
|
||||
|
||||
Returns a BatchRetrievalResult containing successful file retrievals
|
||||
and errors for any files that could not be fetched.
|
||||
Automatically splits into chunks of MAX_BATCH_SIZE.
|
||||
"""
|
||||
fields = _get_single_file_fields(field_type)
|
||||
if len(web_view_links) <= MAX_BATCH_SIZE:
|
||||
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
|
||||
|
||||
combined = BatchRetrievalResult()
|
||||
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
|
||||
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
|
||||
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
|
||||
combined.files.update(chunk_result.files)
|
||||
combined.errors.update(chunk_result.errors)
|
||||
return combined
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Single-batch implementation."""
|
||||
|
||||
result = BatchRetrievalResult()
|
||||
|
||||
def callback(
|
||||
request_id: str,
|
||||
response: GoogleDriveFileType,
|
||||
exception: Exception | None,
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
result.errors[request_id] = exception
|
||||
else:
|
||||
result.files[request_id] = response
|
||||
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
result.errors[web_view_link] = e
|
||||
|
||||
batch.execute()
|
||||
return result
|
||||
|
||||
@@ -298,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
|
||||
|
||||
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
|
||||
If include_permissions is True, the documents will have permissions synced.
|
||||
May also yield HierarchyNode objects for ancestor folders of resolved documents.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HierarchyConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_hierarchy(
|
||||
|
||||
@@ -1,24 +1,33 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -55,7 +64,6 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
@@ -87,6 +95,7 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -99,7 +108,18 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
@@ -126,7 +146,33 @@ def update_api_key(
|
||||
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -631,6 +631,91 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -853,6 +938,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -13,19 +13,26 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
"""Whether this account type supports interactive web login."""
|
||||
return self not in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
|
||||
@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -47,7 +46,6 @@ async def fetch_user_for_pat(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
95
backend/onyx/db/permissions.py
Normal file
95
backend/onyx/db/permissions.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(
|
||||
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for one or more users.
|
||||
|
||||
Accepts a single UUID or a list. Uses a single query regardless of
|
||||
how many users are passed, avoiding N+1 issues.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
if isinstance(user_ids, (UUID, str)):
|
||||
uid_list = [user_ids]
|
||||
else:
|
||||
uid_list = list(user_ids)
|
||||
|
||||
if not uid_list:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(uid_list),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
|
||||
for uid in uid_list:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid) # type: ignore[arg-type]
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
recompute_user_permissions__no_commit(user_ids, db_session)
|
||||
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -9,12 +9,17 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -23,13 +28,53 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database."""
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if new_role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -47,8 +92,16 @@ def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True."""
|
||||
"""Activate a user by setting is_active to True.
|
||||
|
||||
Also reconciles default-group membership — the user may have been
|
||||
created while inactive or deactivated before the backfill migration.
|
||||
"""
|
||||
user.is_active = True
|
||||
if user.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -229,7 +282,9 @@ def get_memories_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Memory]:
|
||||
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
return db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
|
||||
).all()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
|
||||
@@ -17,8 +17,9 @@ from sqlalchemy.sql.expression import or_
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -27,11 +28,17 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
requested_role: UserRole,
|
||||
current_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
@@ -41,19 +48,18 @@ def validate_user_role_update(
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current role is a slack user
|
||||
- current role is an external permissioned user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current role is a limited user
|
||||
"""
|
||||
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
if current_account_type == AccountType.BOT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
@@ -298,6 +304,7 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -306,8 +313,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -344,6 +352,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -375,6 +384,81 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -421,13 +505,14 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
stmt = (
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -435,7 +520,11 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -37,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance.
|
||||
# Defaults to 500 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance; bumped from 100 to 500 to improve recall.
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 500)
|
||||
)
|
||||
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
@@ -335,6 +336,11 @@ def get_all_chunks_paginated(
|
||||
"format.tensors": "short-value",
|
||||
"slices": total_slices,
|
||||
"sliceId": slice_id,
|
||||
# When exceeded, Vespa should return gracefully with partial
|
||||
# results. Even if no hits are returned, Vespa should still return a
|
||||
# new continuation token representing a new spot in the linear
|
||||
# traversal.
|
||||
"timeout": VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT,
|
||||
}
|
||||
if continuation_token is not None:
|
||||
params["continuation"] = continuation_token
|
||||
@@ -343,6 +349,9 @@ def get_all_chunks_paginated(
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
with get_vespa_http_client(
|
||||
# When exceeded, an exception is raised in our code. No progress
|
||||
# is saved, and the task will retry this spot in the traversal
|
||||
# later.
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@@ -19,6 +20,7 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
@@ -353,6 +355,94 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
if not matrix:
|
||||
return matrix
|
||||
|
||||
# Column cleanup — determine which columns to keep without transposing.
|
||||
num_cols = len(matrix[0])
|
||||
keep_cols = _columns_to_keep(matrix, num_cols, max_empty=MAX_EMPTY_COLS)
|
||||
if len(keep_cols) < num_cols:
|
||||
matrix = [[row[c] for c in keep_cols] for row in matrix]
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _columns_to_keep(
|
||||
matrix: list[list[str]], num_cols: int, max_empty: int
|
||||
) -> list[int]:
|
||||
"""Return the indices of columns to keep after removing empty-column runs.
|
||||
|
||||
Uses the same logic as ``_remove_empty_runs`` but operates on column
|
||||
indices so no transpose is needed.
|
||||
"""
|
||||
kept: list[int] = []
|
||||
empty_buffer: list[int] = []
|
||||
|
||||
for col_idx in range(num_cols):
|
||||
col_is_empty = all(not row[col_idx] for row in matrix)
|
||||
if col_is_empty:
|
||||
empty_buffer.append(col_idx)
|
||||
else:
|
||||
kept.extend(empty_buffer[:max_empty])
|
||||
kept.append(col_idx)
|
||||
empty_buffer = []
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Leading empty runs are capped to max_empty, just like interior runs.
|
||||
Trailing empty rows are always dropped since there is no subsequent
|
||||
non-empty row to flush them.
|
||||
"""
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
if len(empty_buffer) < max_empty:
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
@@ -391,30 +481,15 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,114 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionSection(BaseModel):
|
||||
"""Represents a single section of a document — either text or image, not both.
|
||||
|
||||
Text section: set `text`, leave `image_file_id` null.
|
||||
Image section: set `image_file_id`, leave `text` null.
|
||||
"""
|
||||
|
||||
text: str | None = Field(
|
||||
default=None,
|
||||
description="Text content of this section. Set for text sections, null for image sections.",
|
||||
)
|
||||
link: str | None = Field(
|
||||
default=None,
|
||||
description="Optional URL associated with this section. Preserve the original link from the payload if you want it retained.",
|
||||
)
|
||||
image_file_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Opaque identifier for an image stored in the file store. "
|
||||
"The image content is not included — this field signals that the section is an image. "
|
||||
"Hooks can use its presence to reorder or drop image sections, but cannot read or modify the image itself."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentIngestionOwner(BaseModel):
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable name of the owner.",
|
||||
)
|
||||
email: str | None = Field(
|
||||
default=None,
|
||||
description="Email address of the owner.",
|
||||
)
|
||||
|
||||
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
document_id: str = Field(
|
||||
description="Unique identifier for the document. Read-only — changes are ignored."
|
||||
)
|
||||
title: str | None = Field(description="Title of the document.")
|
||||
semantic_identifier: str = Field(
|
||||
description="Human-readable identifier used for display (e.g. file name, page title)."
|
||||
)
|
||||
source: str = Field(
|
||||
description=(
|
||||
"Connector source type (e.g. confluence, slack, google_drive). "
|
||||
"Read-only — changes are ignored. "
|
||||
"Full list of values: https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/configs/constants.py#L195"
|
||||
)
|
||||
)
|
||||
sections: list[DocumentIngestionSection] = Field(
|
||||
description="Sections of the document. Includes both text sections (text set, image_file_id null) and image sections (image_file_id set, text null)."
|
||||
)
|
||||
metadata: dict[str, list[str]] = Field(
|
||||
description="Key-value metadata attached to the document. Values are always a list of strings."
|
||||
)
|
||||
doc_updated_at: str | None = Field(
|
||||
description="ISO 8601 UTC timestamp of the last update at the source, or null if unknown. Example: '2024-03-15T10:30:00+00:00'."
|
||||
)
|
||||
primary_owners: list[DocumentIngestionOwner] | None = Field(
|
||||
description="Primary owners of the document, or null if not available."
|
||||
)
|
||||
secondary_owners: list[DocumentIngestionOwner] | None = Field(
|
||||
description="Secondary owners of the document, or null if not available."
|
||||
)
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
sections: list[DocumentIngestionSection] | None = Field(
|
||||
description="The sections to index, in the desired order. Reorder, drop, or modify sections freely. Null or empty list drops the document."
|
||||
)
|
||||
rejection_reason: str | None = Field(
|
||||
default=None,
|
||||
description="Logged when sections is null or empty. Falls back to a generic message if omitted.",
|
||||
)
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
"""Hook point that runs on every document before it enters the indexing pipeline.
|
||||
|
||||
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
|
||||
Call site: immediately after Onyx's internal validation and before the
|
||||
indexing pipeline begins — no partial writes have occurred yet.
|
||||
|
||||
If a Document Ingestion hook is configured, it takes precedence —
|
||||
Document Ingestion Light will not run. Configure only one per deployment.
|
||||
|
||||
Supported use cases:
|
||||
- Document filtering: drop documents based on content or metadata
|
||||
- Content rewriting: redact PII or normalize text before indexing
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.DOCUMENT_INGESTION
|
||||
display_name = "Document Ingestion"
|
||||
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
|
||||
description = (
|
||||
"Runs on every document before it enters the indexing pipeline. "
|
||||
"Allows filtering, rewriting, or dropping documents."
|
||||
)
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
docs_url = "https://docs.onyx.app/admins/advanced_configs/hook_extensions#document-ingestion"
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -65,8 +65,9 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
docs_url = (
|
||||
"https://docs.onyx.app/admins/advanced_configs/hook_extensions#query-processing"
|
||||
)
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -47,6 +48,13 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionOwner
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionPayload
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
@@ -297,6 +305,7 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
adapter: IndexingBatchAdapter,
|
||||
ignore_time_skip: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
@@ -310,6 +319,7 @@ def index_doc_batch_with_handler(
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
adapter=adapter,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
@@ -785,6 +795,132 @@ def _verify_indexing_completeness(
|
||||
)
|
||||
|
||||
|
||||
def _apply_document_ingestion_hook(
|
||||
documents: list[Document],
|
||||
db_session: Session,
|
||||
) -> list[Document]:
|
||||
"""Apply the Document Ingestion hook to each document in the batch.
|
||||
|
||||
- HookSkipped / HookSoftFailed → document passes through unchanged.
|
||||
- Response with sections=None → document is dropped (logged).
|
||||
- Response with sections → document sections are replaced with the hook's output.
|
||||
"""
|
||||
|
||||
def _build_payload(doc: Document) -> DocumentIngestionPayload:
|
||||
return DocumentIngestionPayload(
|
||||
document_id=doc.id or "",
|
||||
title=doc.title,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
source=doc.source.value if doc.source is not None else "",
|
||||
sections=[
|
||||
DocumentIngestionSection(
|
||||
text=s.text if isinstance(s, TextSection) else None,
|
||||
link=s.link,
|
||||
image_file_id=(
|
||||
s.image_file_id if isinstance(s, ImageSection) else None
|
||||
),
|
||||
)
|
||||
for s in doc.sections
|
||||
],
|
||||
metadata={
|
||||
k: v if isinstance(v, list) else [v] for k, v in doc.metadata.items()
|
||||
},
|
||||
doc_updated_at=(
|
||||
doc.doc_updated_at.isoformat() if doc.doc_updated_at else None
|
||||
),
|
||||
primary_owners=(
|
||||
[
|
||||
DocumentIngestionOwner(
|
||||
display_name=o.get_semantic_name() or None,
|
||||
email=o.email,
|
||||
)
|
||||
for o in doc.primary_owners
|
||||
]
|
||||
if doc.primary_owners
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
[
|
||||
DocumentIngestionOwner(
|
||||
display_name=o.get_semantic_name() or None,
|
||||
email=o.email,
|
||||
)
|
||||
for o in doc.secondary_owners
|
||||
]
|
||||
if doc.secondary_owners
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def _apply_result(
|
||||
doc: Document,
|
||||
hook_result: DocumentIngestionResponse | HookSkipped | HookSoftFailed,
|
||||
) -> Document | None:
|
||||
"""Return the modified doc, original doc (skip/soft-fail), or None (drop)."""
|
||||
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
|
||||
return doc
|
||||
if not hook_result.sections:
|
||||
reason = hook_result.rejection_reason or "Document rejected by hook"
|
||||
logger.info(
|
||||
f"Document ingestion hook dropped document doc_id={doc.id!r}: {reason}"
|
||||
)
|
||||
return None
|
||||
new_sections: list[TextSection | ImageSection] = []
|
||||
for s in hook_result.sections:
|
||||
if s.image_file_id is not None:
|
||||
new_sections.append(
|
||||
ImageSection(image_file_id=s.image_file_id, link=s.link)
|
||||
)
|
||||
elif s.text is not None:
|
||||
new_sections.append(TextSection(text=s.text, link=s.link))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Document ingestion hook returned a section with neither text nor "
|
||||
f"image_file_id for doc_id={doc.id!r} — skipping section."
|
||||
)
|
||||
if not new_sections:
|
||||
logger.info(
|
||||
f"Document ingestion hook produced no valid sections for doc_id={doc.id!r} — dropping document."
|
||||
)
|
||||
return None
|
||||
return doc.model_copy(update={"sections": new_sections})
|
||||
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
# Run the hook for the first document. If it returns HookSkipped the hook
|
||||
# is not configured — skip the remaining N-1 DB lookups.
|
||||
first_doc = documents[0]
|
||||
first_payload = _build_payload(first_doc).model_dump()
|
||||
first_hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.DOCUMENT_INGESTION,
|
||||
payload=first_payload,
|
||||
response_type=DocumentIngestionResponse,
|
||||
)
|
||||
if isinstance(first_hook_result, HookSkipped):
|
||||
return documents
|
||||
|
||||
result: list[Document] = []
|
||||
first_applied = _apply_result(first_doc, first_hook_result)
|
||||
if first_applied is not None:
|
||||
result.append(first_applied)
|
||||
|
||||
for doc in documents[1:]:
|
||||
payload = _build_payload(doc).model_dump()
|
||||
hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.DOCUMENT_INGESTION,
|
||||
payload=payload,
|
||||
response_type=DocumentIngestionResponse,
|
||||
)
|
||||
applied = _apply_result(doc, hook_result)
|
||||
if applied is not None:
|
||||
result.append(applied)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -794,6 +930,7 @@ def index_doc_batch(
|
||||
document_indices: list[DocumentIndex],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
adapter: IndexingBatchAdapter,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -818,6 +955,7 @@ def index_doc_batch(
|
||||
)
|
||||
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
filtered_documents = _apply_document_ingestion_hook(filtered_documents, db_session)
|
||||
context = adapter.prepare(filtered_documents, ignore_time_skip)
|
||||
if not context:
|
||||
return IndexingPipelineResult.empty(len(filtered_documents))
|
||||
@@ -1005,6 +1143,7 @@ def run_indexing_pipeline(
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
adapter=adapter,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
|
||||
@@ -175,6 +175,28 @@ def _strip_tool_content_from_messages(
|
||||
return result
|
||||
|
||||
|
||||
def _fix_tool_user_message_ordering(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Insert a synthetic assistant message between tool and user messages.
|
||||
|
||||
Some models (e.g. Mistral on Azure) require strict message ordering where
|
||||
a user message cannot immediately follow a tool message. This function
|
||||
inserts a minimal assistant message to bridge the gap.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
result: list[dict[str, Any]] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev_role = result[-1].get("role")
|
||||
curr_role = msg.get("role")
|
||||
if prev_role == "tool" and curr_role == "user":
|
||||
result.append({"role": "assistant", "content": "Noted. Continuing."})
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
@@ -576,6 +598,18 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Some models (e.g. Mistral) reject a user message
|
||||
# immediately after a tool message. Insert a synthetic
|
||||
# assistant bridge message to satisfy the ordering
|
||||
# constraint. Check both the provider and the deployment/
|
||||
# model name to catch Mistral hosted on Azure.
|
||||
model_or_deployment = (
|
||||
self._deployment_name or self._model_version or ""
|
||||
).lower()
|
||||
is_mistral_model = is_mistral or "mistral" in model_or_deployment
|
||||
if is_mistral_model:
|
||||
messages = _fix_tool_user_message_ordering(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -3,10 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
@@ -247,7 +247,7 @@ def handle_message(
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
and existing_user.account_type == AccountType.BOT
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
|
||||
@@ -147,6 +147,7 @@ class UserInfo(BaseModel):
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
memories: list[MemoryItem] | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -191,10 +192,7 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
enable_memory_tool=user.enable_memory_tool,
|
||||
memories=[
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in (user.memories or [])
|
||||
],
|
||||
memories=memories or [],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -50,6 +51,7 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -57,6 +59,7 @@ from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import get_memories_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
@@ -141,6 +144,7 @@ def set_user_role(
|
||||
validate_user_role_update(
|
||||
requested_role=requested_role,
|
||||
current_role=current_role,
|
||||
current_account_type=user_to_update.account_type,
|
||||
explicit_override=user_role_update_request.explicit_override,
|
||||
)
|
||||
|
||||
@@ -326,8 +330,8 @@ def list_all_users(
|
||||
if (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
|
||||
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
|
||||
slack_users = [user for user in users if user.account_type == AccountType.BOT]
|
||||
accepted_users = [user for user in users if user.account_type != AccountType.BOT]
|
||||
|
||||
accepted_emails = {user.email for user in accepted_users}
|
||||
slack_users_emails = {user.email for user in slack_users}
|
||||
@@ -670,7 +674,7 @@ def list_all_users_basic_info(
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.role != UserRole.SLACK_USER
|
||||
if user.account_type != AccountType.BOT
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
@@ -773,6 +777,13 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
@@ -823,6 +834,11 @@ def verify_user_logged_in(
|
||||
[],
|
||||
),
|
||||
)
|
||||
memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
@@ -833,6 +849,7 @@ def verify_user_logged_in(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
return user_info
|
||||
@@ -930,7 +947,8 @@ def update_user_personalization_api(
|
||||
else user.enable_memory_tool
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -41,6 +42,7 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -60,6 +62,7 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -458,6 +458,27 @@ def run_async_sync_no_cancel(coro: Awaitable[T]) -> T:
|
||||
return future.result()
|
||||
|
||||
|
||||
def run_multiple_in_background(
|
||||
funcs: list[Callable[[], None]],
|
||||
thread_name_prefix: str = "worker",
|
||||
) -> ThreadPoolExecutor:
|
||||
"""Submit multiple callables to a ``ThreadPoolExecutor`` with context propagation.
|
||||
|
||||
Copies the current ``contextvars`` context once and runs every callable
|
||||
inside that copy, which is important for preserving tenant IDs and other
|
||||
context-local state across threads.
|
||||
|
||||
Returns the executor so the caller can ``shutdown()`` when done.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=len(funcs), thread_name_prefix=thread_name_prefix
|
||||
)
|
||||
for func in funcs:
|
||||
executor.submit(ctx.run, func)
|
||||
return executor
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Tests for GoogleDriveConnector.resolve_errors against real Google Drive."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
ALL_EXPECTED_HIERARCHY_NODES,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
|
||||
|
||||
_DRIVE_ID_MAPPING_PATH = os.path.join(
|
||||
os.path.dirname(__file__), "drive_id_mapping.json"
|
||||
)
|
||||
|
||||
|
||||
def _load_web_view_links(file_ids: list[int]) -> list[str]:
|
||||
with open(_DRIVE_ID_MAPPING_PATH) as f:
|
||||
mapping: dict[str, str] = json.load(f)
|
||||
return [mapping[str(fid)] for fid in file_ids]
|
||||
|
||||
|
||||
def _build_failures(web_view_links: list[str]) -> list[ConnectorFailure]:
|
||||
return [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=link,
|
||||
document_link=link,
|
||||
),
|
||||
failure_message=f"Synthetic failure for {link}",
|
||||
)
|
||||
for link in web_view_links
|
||||
]
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_single_file(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Resolve a single known file and verify we get back exactly one Document."""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
web_view_links = _load_web_view_links([0])
|
||||
failures = _build_failures(web_view_links)
|
||||
|
||||
results = list(connector.resolve_errors(failures))
|
||||
|
||||
docs = [r for r in results if isinstance(r, Document)]
|
||||
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
|
||||
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
|
||||
|
||||
assert len(docs) == 1
|
||||
assert len(new_failures) == 0
|
||||
assert docs[0].semantic_identifier == "file_0.txt"
|
||||
|
||||
# Should yield at least one hierarchy node (the file's parent folder chain)
|
||||
assert len(hierarchy_nodes) > 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_multiple_files(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Resolve multiple files across different folders via batch API."""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
# Pick files from different folders: admin files (0-4), shared drive 1 (20-24), folder_2 (45-49)
|
||||
file_ids = [0, 1, 20, 21, 45]
|
||||
web_view_links = _load_web_view_links(file_ids)
|
||||
failures = _build_failures(web_view_links)
|
||||
|
||||
results = list(connector.resolve_errors(failures))
|
||||
|
||||
docs = [r for r in results if isinstance(r, Document)]
|
||||
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
|
||||
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
|
||||
|
||||
assert len(new_failures) == 0
|
||||
retrieved_names = {doc.semantic_identifier for doc in docs}
|
||||
expected_names = {f"file_{fid}.txt" for fid in file_ids}
|
||||
assert expected_names == retrieved_names
|
||||
|
||||
# Files span multiple folders, so we should get hierarchy nodes
|
||||
assert len(hierarchy_nodes) > 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_hierarchy_nodes_are_valid(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Verify that hierarchy nodes from resolve_errors match expected structure."""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
# File in folder_1 (inside shared_drive_1) — should walk up to shared_drive_1 root
|
||||
web_view_links = _load_web_view_links([25])
|
||||
failures = _build_failures(web_view_links)
|
||||
|
||||
results = list(connector.resolve_errors(failures))
|
||||
|
||||
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
|
||||
node_ids = {node.raw_node_id for node in hierarchy_nodes}
|
||||
|
||||
# File 25 is in folder_1 which is inside shared_drive_1.
|
||||
# The parent walk must yield at least these two ancestors.
|
||||
assert (
|
||||
FOLDER_1_ID in node_ids
|
||||
), f"Expected folder_1 ({FOLDER_1_ID}) in hierarchy nodes, got: {node_ids}"
|
||||
assert (
|
||||
SHARED_DRIVE_1_ID in node_ids
|
||||
), f"Expected shared_drive_1 ({SHARED_DRIVE_1_ID}) in hierarchy nodes, got: {node_ids}"
|
||||
|
||||
for node in hierarchy_nodes:
|
||||
if node.raw_node_id not in ALL_EXPECTED_HIERARCHY_NODES:
|
||||
continue
|
||||
expected = ALL_EXPECTED_HIERARCHY_NODES[node.raw_node_id]
|
||||
assert node.display_name == expected.display_name, (
|
||||
f"Display name mismatch for {node.raw_node_id}: "
|
||||
f"expected '{expected.display_name}', got '{node.display_name}'"
|
||||
)
|
||||
assert node.node_type == expected.node_type, (
|
||||
f"Node type mismatch for {node.raw_node_id}: "
|
||||
f"expected '{expected.node_type}', got '{node.node_type}'"
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_with_invalid_link(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Resolve with a mix of valid and invalid links — invalid ones yield ConnectorFailure."""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
valid_links = _load_web_view_links([0])
|
||||
invalid_link = "https://drive.google.com/file/d/NONEXISTENT_FILE_ID_12345"
|
||||
failures = _build_failures(valid_links + [invalid_link])
|
||||
|
||||
results = list(connector.resolve_errors(failures))
|
||||
|
||||
docs = [r for r in results if isinstance(r, Document)]
|
||||
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].semantic_identifier == "file_0.txt"
|
||||
assert len(new_failures) == 1
|
||||
assert new_failures[0].failed_document is not None
|
||||
assert new_failures[0].failed_document.document_id == invalid_link
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_empty_errors(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Resolving an empty error list should yield nothing."""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
results = list(connector.resolve_errors([]))
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
|
||||
def test_resolve_entity_failures_are_skipped(
|
||||
mock_api_key: None, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""Entity failures (not document failures) should be skipped by resolve_errors."""
|
||||
from onyx.connectors.models import EntityFailure
|
||||
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=False,
|
||||
)
|
||||
|
||||
entity_failure = ConnectorFailure(
|
||||
failed_entity=EntityFailure(entity_id="some_stage"),
|
||||
failure_message="retrieval failure",
|
||||
)
|
||||
|
||||
results = list(connector.resolve_errors([entity_failure]))
|
||||
|
||||
assert len(results) == 0
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -52,7 +53,12 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -68,7 +74,8 @@ def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,16 +13,29 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -100,9 +113,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -175,9 +188,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -272,8 +285,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -357,7 +370,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -413,7 +426,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -452,8 +465,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -52,6 +53,7 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -46,6 +47,7 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -126,6 +126,15 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,13 +104,30 @@ class UserGroupManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser,
|
||||
include_default: bool = False,
|
||||
) -> list[UserGroup]:
|
||||
params: dict[str, str] = {}
|
||||
if include_default:
|
||||
params["include_default"] = "true"
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -33,3 +37,120 @@ def test_limited(reset: None) -> None: # noqa: ARG001
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def _get_service_account_account_type(
|
||||
admin_user: DATestUser,
|
||||
api_key_user_id: UUID,
|
||||
) -> AccountType:
|
||||
"""Fetch the account_type of a service account user via the user listing API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/users",
|
||||
headers=admin_user.headers,
|
||||
params={"include_api_keys": "true"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
user_id_str = str(api_key_user_id)
|
||||
for user in data["accepted"]:
|
||||
if user["id"] == user_id_str:
|
||||
return AccountType(user["account_type"])
|
||||
raise AssertionError(
|
||||
f"Service account user {user_id_str} not found in user listing"
|
||||
)
|
||||
|
||||
|
||||
def _get_default_group_user_ids(
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
admin_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_ids = {str(u.id) for u in basic_group.users}
|
||||
return admin_ids, basic_ids
|
||||
|
||||
|
||||
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify no group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "LIMITED API key should NOT be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "LIMITED API key should NOT be in Basic default group"
|
||||
|
||||
|
||||
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.BASIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Basic group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "BASIC API key should NOT be in Admin default group"
|
||||
|
||||
|
||||
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.ADMIN,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Admin group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "ADMIN API key should NOT be in Basic default group"
|
||||
|
||||
@@ -4,11 +4,32 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _simulate_saml_login(email: str, admin_user: DATestUser) -> dict:
|
||||
"""Simulate a SAML login by calling the test upsert endpoint."""
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
"""Get the set of emails of all members in the Basic default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
|
||||
assert basic_default, "Basic default group not found"
|
||||
return {u.email for u in basic_default[0].users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
@@ -49,15 +70,9 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login by calling the test endpoint
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_user_email},
|
||||
headers=admin_user.headers, # Use admin headers for authorization
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = _simulate_saml_login(test_user_email, admin_user)
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify user role was changed in the database
|
||||
@@ -82,16 +97,237 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
|
||||
|
||||
# Simulate SAML login again
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": slack_user_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = _simulate_saml_login(slack_user_email, admin_user)
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD when converting a
|
||||
non-web user (EXT_PERM_USER) and that the user receives the correct role
|
||||
(BASIC) after conversion.
|
||||
|
||||
This validates the permissions-migration-phase2 changes which ensure that:
|
||||
1. account_type is updated to 'standard' on SAML conversion
|
||||
2. The converted user is assigned to the Basic default group
|
||||
"""
|
||||
# Create an admin user (first user is automatically admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# Create a user and set them as EXT_PERM_USER
|
||||
test_email = "ext_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type is set to standard after conversion
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
|
||||
|
||||
# Verify role is BASIC after conversion
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user was assigned to the Basic default group
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_normal_signin_assigns_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that a brand-new user signing in via SAML for the first time
|
||||
is created with the correct role, account_type, and group membership.
|
||||
|
||||
This validates that normal SAML sign-in (not an upgrade from
|
||||
SLACK_USER/EXT_PERM_USER) correctly:
|
||||
1. Creates the user with role=BASIC and account_type=STANDARD
|
||||
2. Assigns the user to the Basic default group
|
||||
"""
|
||||
# First user becomes admin
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# New user signs in via SAML (no prior account)
|
||||
new_email = "new_saml_user@example.com"
|
||||
user_data = _simulate_saml_login(new_email, admin_user)
|
||||
|
||||
# Verify role and account_type
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert user_data["account_type"] == AccountType.STANDARD.value
|
||||
|
||||
# Verify user is in the Basic default group
|
||||
assert new_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"New SAML user '{new_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_restores_group_membership(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login restores Basic group membership when converting
|
||||
a non-authenticated user (EXT_PERM_USER or SLACK_USER) to BASIC.
|
||||
|
||||
Group membership implies 'basic' permission (verified by
|
||||
test_new_group_gets_basic_permission).
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perm_perms@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
assert ext_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(ext_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert ext_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "EXT_PERM_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(slack_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert slack_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "SLACK_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_round_trip_group_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test the full round-trip: BASIC -> EXT_PERM -> SAML(BASIC) -> EXT_PERM -> SAML(BASIC).
|
||||
|
||||
Verifies group membership is correctly removed and restored at each transition.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "roundtrip@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
# Step 1: BASIC user is in Basic group
|
||||
assert test_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 2: Downgrade to EXT_PERM_USER — loses Basic group
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 3: SAML login — converts back to BASIC, regains Basic group
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after first SAML conversion"
|
||||
|
||||
# Step 4: Downgrade again
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 5: SAML login again — should still restore correctly
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after second SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_slack_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD and assigns Basic group
|
||||
when converting a SLACK_USER (BOT account_type).
|
||||
|
||||
Mirrors test_saml_user_conversion_sets_account_type_and_group but for
|
||||
SLACK_USER instead of EXT_PERM_USER, and additionally verifies permissions.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "slack_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.SLACK_USER)
|
||||
|
||||
# SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type and role
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected STANDARD, got {user_data['account_type']}"
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify Basic group membership (implies 'basic' permission)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted SLACK_USER '{test_email}' not found in Basic default group"
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Integration tests for permission propagation across auth-triggered group changes.
|
||||
|
||||
These tests verify that effective permissions (via /me/permissions) actually
|
||||
propagate when users are added/removed from default groups through role changes.
|
||||
Custom permission grant tests will be added once the permission grant API is built.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.is_default and g.name == "Basic"), None
|
||||
)
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
return {u.email for u in basic_group.users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_basic_permission_granted_on_registration(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""New users should get 'basic' permission through default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
basic_user: DATestUser = UserManager.create(email="basic@example.com")
|
||||
|
||||
# Admin should have permissions from Admin group
|
||||
admin_perms = UserManager.get_permissions(admin_user)
|
||||
assert "basic" in admin_perms
|
||||
|
||||
# Basic user should have 'basic' from Basic default group
|
||||
basic_perms = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in basic_perms
|
||||
|
||||
# Verify group membership matches
|
||||
assert basic_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_role_downgrade_removes_basic_group_and_permission(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Downgrading to EXT_PERM_USER or SLACK_USER should remove from Basic group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER ---
|
||||
ext_user: DATestUser = UserManager.create(email="ext@example.com")
|
||||
assert ext_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# --- SLACK_USER ---
|
||||
slack_user: DATestUser = UserManager.create(email="slack@example.com")
|
||||
assert slack_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
@@ -21,8 +21,15 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
@@ -44,13 +51,6 @@ def scim_token(idp_style: str) -> str:
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
@@ -550,3 +550,145 @@ def test_patch_add_duplicate_member_is_idempotent(
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["members"]) == 1 # still just one member
|
||||
|
||||
|
||||
def test_create_group_reserved_name_admin(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Admin' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_create_group_reserved_name_basic(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Basic' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_to_reserved(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""PUT /Groups/{id} renaming a group to 'Admin' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Rename To Reserved {idp_style}",
|
||||
external_id=f"ext-rtr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="Admin", external_id=f"ext-rtr-{idp_style}"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a group to 'Basic' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Rename Reserved {idp_style}",
|
||||
external_id=f"ext-prr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "Basic"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_delete_reserved_group_rejected(scim_token: str) -> None:
|
||||
"""DELETE /Groups/{id} on a reserved group ('Admin') returns 409."""
|
||||
# Look up the reserved 'Admin' group via SCIM filter
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_scim_created_group_has_basic_permission(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""POST /Groups assigns the 'basic' permission to the group itself."""
|
||||
# Create a SCIM group (no members needed — we check the group's permissions)
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Basic Perm Group {idp_style}",
|
||||
external_id=f"ext-basic-perm-{idp_style}",
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
group_id = resp.json()["id"]
|
||||
|
||||
# Log in as the admin user (created by the scim_token fixture).
|
||||
admin = DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
admin = UserManager.login_as_user(admin)
|
||||
|
||||
# Verify the group itself was granted the basic permission
|
||||
perms_resp = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions",
|
||||
headers=admin.headers,
|
||||
)
|
||||
perms_resp.raise_for_status()
|
||||
perms = perms_resp.json()
|
||||
assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}"
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None:
|
||||
"""PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="RenamedAdmin", external_id="ext-rename-from-reserved"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
@@ -35,9 +35,16 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
@@ -211,6 +218,49 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_create_user_default_group_and_account_type(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
|
||||
email = f"scim_defaults_{idp_style}@example.com"
|
||||
ext_id = f"ext-defaults-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
user_id = resp.json()["id"]
|
||||
|
||||
# --- Verify group assignment via SCIM GET ---
|
||||
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
|
||||
assert get_resp.status_code == 200
|
||||
groups = get_resp.json().get("groups", [])
|
||||
group_names = {g["display"] for g in groups}
|
||||
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
|
||||
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
|
||||
|
||||
# --- Verify account_type via admin API ---
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
page = UserManager.get_user_page(
|
||||
user_performing_action=admin,
|
||||
search_query=email,
|
||||
)
|
||||
assert page.total_items >= 1
|
||||
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
|
||||
assert (
|
||||
scim_user_snapshot is not None
|
||||
), f"SCIM user {email} not found in user listing"
|
||||
assert (
|
||||
scim_user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_gets_permissions_when_added_to_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
|
||||
|
||||
# basic_user starts with only "basic" from the default group
|
||||
initial_permissions = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in initial_permissions
|
||||
assert "add:agents" not in initial_permissions
|
||||
|
||||
# Create a new group and add basic_user
|
||||
group = UserGroupManager.create(
|
||||
name="perm-test-group",
|
||||
user_ids=[admin_user.id, basic_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Grant a non-basic permission to the group and recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_group = db_session.get(UserGroupModel, group.id)
|
||||
assert db_group is not None
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
recompute_user_permissions__no_commit(basic_user.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the user gained the new permission (expanded includes read:agents)
|
||||
updated_permissions = UserManager.get_permissions(basic_user)
|
||||
assert (
|
||||
"add:agents" in updated_permissions
|
||||
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
|
||||
assert (
|
||||
"read:agents" in updated_permissions
|
||||
), f"User should have implied 'read:agents', got: {updated_permissions}"
|
||||
assert "basic" in updated_permissions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_group_permission_change_propagates_to_all_members(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_propagate")
|
||||
user_a: DATestUser = UserManager.create(name="user_a_propagate")
|
||||
user_b: DATestUser = UserManager.create(name="user_b_propagate")
|
||||
|
||||
group = UserGroupManager.create(
|
||||
name="propagate-test-group",
|
||||
user_ids=[admin_user.id, user_a.id, user_b.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Neither user should have add:agents yet
|
||||
for u in (user_a, user_b):
|
||||
assert "add:agents" not in UserManager.get_permissions(u)
|
||||
|
||||
# Grant add:agents to the group, then batch-recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
grant = PermissionGrant(
|
||||
group_id=group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
db_session.add(grant)
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Both users should now have the permission (plus implied read:agents)
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
|
||||
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
|
||||
|
||||
# Soft-delete the grant and recompute — permission should be removed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_grant = (
|
||||
db_session.query(PermissionGrant)
|
||||
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
|
||||
.first()
|
||||
)
|
||||
assert db_grant is not None
|
||||
db_grant.is_deleted = True
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
|
||||
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
|
||||
|
||||
user_group = UserGroupManager.create(
|
||||
name="basic-perm-test-group",
|
||||
user_ids=[admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
permissions = UserGroupManager.get_permissions(
|
||||
user_group=user_group,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
"basic" in permissions
|
||||
), f"New group should have 'basic' permission, got: {permissions}"
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Integration tests for default group assignment on user registration.
|
||||
|
||||
Verifies that:
|
||||
- The first registered user is assigned to the Admin default group
|
||||
- Subsequent registered users are assigned to the Basic default group
|
||||
- account_type is set to STANDARD for email/password registrations
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
|
||||
# Register first user — should become admin
|
||||
admin_user: DATestUser = UserManager.create(name="first_user")
|
||||
assert admin_user.role == UserRole.ADMIN
|
||||
|
||||
# Register second user — should become basic
|
||||
basic_user: DATestUser = UserManager.create(name="second_user")
|
||||
assert basic_user.role == UserRole.BASIC
|
||||
|
||||
# Fetch all groups including default ones
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
|
||||
# Find the default Admin and Basic groups
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
# Verify admin user is in Admin group and NOT in Basic group
|
||||
admin_group_user_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_group_user_ids = {str(u.id) for u in basic_group.users}
|
||||
|
||||
assert (
|
||||
admin_user.id in admin_group_user_ids
|
||||
), "First user should be in Admin default group"
|
||||
assert (
|
||||
admin_user.id not in basic_group_user_ids
|
||||
), "First user should NOT be in Basic default group"
|
||||
|
||||
# Verify basic user is in Basic group and NOT in Admin group
|
||||
assert (
|
||||
basic_user.id in basic_group_user_ids
|
||||
), "Second user should be in Basic default group"
|
||||
assert (
|
||||
basic_user.id not in admin_group_user_ids
|
||||
), "Second user should NOT be in Admin default group"
|
||||
|
||||
# Verify account_type is STANDARD for both users via user listing API
|
||||
paginated_result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
users_by_id = {str(u.id): u for u in paginated_result.items}
|
||||
|
||||
admin_snapshot = users_by_id.get(admin_user.id)
|
||||
basic_snapshot = users_by_id.get(basic_user.id)
|
||||
assert admin_snapshot is not None, "Admin user not found in user listing"
|
||||
assert basic_snapshot is not None, "Basic user not found in user listing"
|
||||
|
||||
assert (
|
||||
admin_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
|
||||
assert (
|
||||
basic_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Integration tests for password signup upgrade paths.
|
||||
|
||||
Verifies that when a BOT or EXT_PERM_USER user signs up via email/password:
|
||||
- Their account_type is upgraded to STANDARD
|
||||
- They are assigned to the Basic default group
|
||||
- They gain the correct effective permissions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"target_role",
|
||||
[UserRole.EXT_PERM_USER, UserRole.SLACK_USER],
|
||||
ids=["ext_perm_user", "slack_user"],
|
||||
)
|
||||
def test_password_signup_upgrade(
|
||||
reset: None, # noqa: ARG001
|
||||
target_role: UserRole,
|
||||
) -> None:
|
||||
"""When a non-web user signs up via email/password, they should be
|
||||
upgraded to STANDARD account_type and assigned to the Basic default group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = f"{target_role.value}_upgrade@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
test_user = UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=target_role,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Verify user was removed from Basic group after downgrade
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email not in basic_emails
|
||||
), f"{target_role.value} should not be in Basic default group"
|
||||
|
||||
# Re-register with the same email — triggers the password signup upgrade
|
||||
upgraded_user = UserManager.create(email=test_email)
|
||||
|
||||
assert upgraded_user.role == UserRole.BASIC
|
||||
|
||||
paginated = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
user_snapshot = next(
|
||||
(u for u in paginated.items if str(u.id) == upgraded_user.id), None
|
||||
)
|
||||
assert user_snapshot is not None
|
||||
assert (
|
||||
user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {user_snapshot.account_type}"
|
||||
|
||||
# Verify user is now in the Basic default group
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email in basic_emails
|
||||
), f"Upgraded user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
def test_password_signup_upgrade_propagates_permissions(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER or SLACK_USER signs up via password, they should
|
||||
gain the 'basic' permission through the Basic default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perms_check@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
|
||||
initial_perms = UserManager.get_permissions(ext_user)
|
||||
assert "basic" in initial_perms
|
||||
|
||||
ext_user = UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert ext_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=ext_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms_check@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
slack_user = UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert slack_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=slack_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}"
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Integration tests for default group reconciliation on user reactivation.
|
||||
|
||||
Verifies that:
|
||||
- A deactivated user retains default group membership after reactivation
|
||||
- Reactivation via the admin API reconciles missing group membership
|
||||
"""
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
def test_reactivated_user_retains_default_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deactivating and reactivating a user should preserve their
|
||||
default group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user")
|
||||
|
||||
# Verify user is in Basic group initially
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert basic_user.email in basic_emails
|
||||
|
||||
# Deactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Reactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify user is still in Basic group after reactivation
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
basic_user.email in basic_emails
|
||||
), "Reactivated user should still be in Basic default group"
|
||||
176
backend/tests/unit/onyx/auth/test_permissions.py
Normal file
176
backend/tests/unit/onyx/auth/test_permissions.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.permissions import ALL_PERMISSIONS
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.permissions import resolve_effective_permissions
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_effective_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveEffectivePermissions:
|
||||
def test_empty_set(self) -> None:
|
||||
assert resolve_effective_permissions(set()) == set()
|
||||
|
||||
def test_basic_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"basic"})
|
||||
assert result == {"basic"}
|
||||
|
||||
def test_single_implication(self) -> None:
|
||||
result = resolve_effective_permissions({"add:agents"})
|
||||
assert result == {"add:agents", "read:agents"}
|
||||
|
||||
def test_manage_agents_implies_add_and_read(self) -> None:
|
||||
"""manage:agents directly maps to {add:agents, read:agents}."""
|
||||
result = resolve_effective_permissions({"manage:agents"})
|
||||
assert result == {"manage:agents", "add:agents", "read:agents"}
|
||||
|
||||
def test_manage_connectors_chain(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:connectors"})
|
||||
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
|
||||
|
||||
def test_manage_document_sets(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:document_sets"})
|
||||
assert result == {
|
||||
"manage:document_sets",
|
||||
"read:document_sets",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_manage_user_groups_implies_all_reads(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:user_groups"})
|
||||
assert result == {
|
||||
"manage:user_groups",
|
||||
"read:connectors",
|
||||
"read:document_sets",
|
||||
"read:agents",
|
||||
"read:users",
|
||||
}
|
||||
|
||||
def test_admin_override(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_admin_with_others(self) -> None:
|
||||
result = resolve_effective_permissions({"admin", "basic"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_multi_group_union(self) -> None:
|
||||
result = resolve_effective_permissions(
|
||||
{"add:agents", "manage:connectors", "basic"}
|
||||
)
|
||||
assert result == {
|
||||
"basic",
|
||||
"add:agents",
|
||||
"read:agents",
|
||||
"manage:connectors",
|
||||
"add:connectors",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_toggle_permission_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"read:agent_analytics"})
|
||||
assert result == {"read:agent_analytics"}
|
||||
|
||||
def test_all_permissions_for_admin(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert len(result) == len(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_permissions (expands implied at read time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEffectivePermissions:
|
||||
def test_expands_implied_permissions(self) -> None:
|
||||
"""Column stores only granted; get_effective_permissions expands implied."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["add:agents"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
|
||||
|
||||
def test_admin_expands_to_all(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set(Permission)
|
||||
|
||||
def test_basic_stays_basic(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.BASIC_ACCESS}
|
||||
|
||||
def test_empty_column(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_permission (FastAPI dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass(self) -> None:
|
||||
"""Admin stored in column should pass any permission check."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_required_permission(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implied_permission_passes(self) -> None:
|
||||
"""manage:connectors implies read:connectors at read time."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.READ_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_permission_raises(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await dep(user=user)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_permissions_fails(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
|
||||
dep = require_permission(Permission.BASIC_ACCESS)
|
||||
with pytest.raises(OnyxError):
|
||||
await dep(user=user)
|
||||
29
backend/tests/unit/onyx/auth/test_user_create_schema.py
Normal file
29
backend/tests/unit/onyx/auth/test_user_create_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Unit tests for UserCreate schema dict methods.
|
||||
|
||||
Verifies that account_type is always included in create_update_dict
|
||||
and create_update_dict_superuser.
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
def test_create_update_dict_includes_default_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_includes_explicit_account_type() -> None:
|
||||
uc = UserCreate(
|
||||
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
|
||||
)
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_superuser_includes_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict_superuser()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
754
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
754
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,754 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 failed")
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_stop_button_calls_completion_for_all_models(self) -> None:
|
||||
"""llm_loop_completion_handle must be called for all models when the stop button fires.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers main-thread completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking), waits for workers via executor.shutdown(wait=True),
|
||||
then calls llm_loop_completion_handle for each successful model from the main
|
||||
thread.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_block_until_drain(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until drain_done is set — simulating a mid-stream LLM call that exits
|
||||
promptly once the emitter signals shutdown.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
# Block until drain_done is set by gen.close(). The Emitter's _drain_done
|
||||
# is the same Event that _run_models sets, so this unblocks promptly.
|
||||
emitter._drain_done.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_block_until_drain,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# gen.close() → GeneratorExit → finally → drain_done.set() →
|
||||
# executor.shutdown(wait=True) → main thread completes models.
|
||||
gen.close()
|
||||
|
||||
assert (
|
||||
completion_called.is_set()
|
||||
), "main thread must call completion for the successful model"
|
||||
assert mock_handle.call_count == 1
|
||||
|
||||
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
|
||||
"""B1 regression: model finishes BEFORE GeneratorExit fires.
|
||||
|
||||
The worker exits _run_model before drain_done is set. When gen.close()
|
||||
fires afterward, the finally block sets drain_done, waits for workers
|
||||
(already done), then the main thread calls llm_loop_completion_handle.
|
||||
|
||||
Contrast with test_http_disconnect_completion_via_generator_exit, which
|
||||
tests the opposite ordering (worker finishes AFTER disconnect).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_and_return_immediately(**kwargs: Any) -> None:
|
||||
# Emit one packet so the drain loop has something to yield, then return
|
||||
# immediately — no blocking. The worker will be done in microseconds.
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_and_return_immediately,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
|
||||
# Give the worker thread time to finish completely (emit + return +
|
||||
# finally + self-completion check). It does almost no work, so 100 ms
|
||||
# is far more than enough while still keeping the test fast.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now close — worker is already done, so else-branch handles completion.
|
||||
gen.close()
|
||||
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "disconnect handler must call completion for a model that already finished"
|
||||
assert mock_handle.call_count == 1, "completion must be called exactly once"
|
||||
|
||||
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
|
||||
"""B2 regression: stop-button must NOT call completion for an errored model.
|
||||
|
||||
When model 0 raises an exception, its reserved ChatMessage must not be
|
||||
saved with 'stopped by user' — that message is wrong for a model that
|
||||
errored. llm_loop_completion_handle must only be called for non-errored
|
||||
models when the stop button fires.
|
||||
"""
|
||||
|
||||
def fail_model_0(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 errored")
|
||||
# Model 1: run forever (stop button fires before it finishes)
|
||||
time.sleep(10)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
# Return False immediately so the stop-button path fires while model 1
|
||||
# is still sleeping (model 0 has already errored by then).
|
||||
setup.check_is_connected = lambda: False
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Completion must NOT be called for model 0 (it errored).
|
||||
# It MAY be called for model 1 (still in-flight when stop fired).
|
||||
for call in mock_handle.call_args_list:
|
||||
assert (
|
||||
call.kwargs.get("llm") is not setup.llms[0]
|
||||
), "llm_loop_completion_handle must not be called for the errored model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
176
backend/tests/unit/onyx/db/test_assign_default_groups.py
Normal file
176
backend/tests/unit/onyx/db/test_assign_default_groups.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
|
||||
|
||||
Covers:
|
||||
1. Standard/service-account users get assigned to the correct default group
|
||||
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
|
||||
3. Missing default group raises RuntimeError
|
||||
4. Already-in-group is a no-op
|
||||
5. IntegrityError race condition is handled gracefully
|
||||
6. The function never commits the session
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
|
||||
def _mock_user(
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
email: str = "test@example.com",
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = email
|
||||
user.account_type = account_type
|
||||
return user
|
||||
|
||||
|
||||
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
group.is_default = True
|
||||
return group
|
||||
|
||||
|
||||
def _make_query_chain(first_return: object = None) -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.first.return_value = first_return
|
||||
return chain
|
||||
|
||||
|
||||
def _setup_db_session(
|
||||
group_result: object = None,
|
||||
membership_result: object = None,
|
||||
) -> MagicMock:
|
||||
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
|
||||
db_session = MagicMock()
|
||||
|
||||
group_chain = _make_query_chain(group_result)
|
||||
membership_chain = _make_query_chain(membership_result)
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model is UserGroup:
|
||||
return group_chain
|
||||
if model is User__UserGroup:
|
||||
return membership_chain
|
||||
return MagicMock()
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
return db_session
|
||||
|
||||
|
||||
def test_standard_user_assigned_to_basic_group() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_id == user.id
|
||||
assert added.user_group_id == group.id
|
||||
db_session.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_user_assigned_to_admin_group() -> None:
|
||||
group = _mock_group("Admin", group_id=2)
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_group_id == group.id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"account_type",
|
||||
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
|
||||
)
|
||||
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
|
||||
db_session = MagicMock()
|
||||
user = _mock_user(account_type)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.query.assert_not_called()
|
||||
db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_service_account_not_skipped() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.SERVICE_ACCOUNT)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
|
||||
|
||||
def test_missing_default_group_raises_error() -> None:
|
||||
db_session = _setup_db_session(group_result=None)
|
||||
user = _mock_user()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Default group .* not found"):
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
|
||||
def test_already_in_group_is_noop() -> None:
|
||||
group = _mock_group("Basic")
|
||||
existing_membership = MagicMock()
|
||||
db_session = _setup_db_session(
|
||||
group_result=group, membership_result=existing_membership
|
||||
)
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
db_session.begin_nested.assert_not_called()
|
||||
|
||||
|
||||
def test_integrity_error_race_condition_handled() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
|
||||
user = _mock_user()
|
||||
|
||||
# Should not raise
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
savepoint.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_no_commit_called_on_successful_assignment() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.commit.assert_not_called()
|
||||
198
backend/tests/unit/onyx/file_processing/test_xlsx_to_text.py
Normal file
198
backend/tests/unit/onyx/file_processing/test_xlsx_to_text.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import io
|
||||
from typing import cast
|
||||
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
|
||||
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
|
||||
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
|
||||
wb = openpyxl.Workbook()
|
||||
if wb.active is not None:
|
||||
wb.remove(cast(Worksheet, wb.active))
|
||||
for sheet_name, rows in sheets.items():
|
||||
ws = wb.create_sheet(title=sheet_name)
|
||||
for row in rows:
|
||||
ws.append(row)
|
||||
buf = io.BytesIO()
|
||||
wb.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestXlsxToText:
|
||||
def test_single_sheet_basic(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["Name", "Age"],
|
||||
["Alice", "30"],
|
||||
["Bob", "25"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "Name" in lines[0]
|
||||
assert "Age" in lines[0]
|
||||
assert "Alice" in lines[1]
|
||||
assert "30" in lines[1]
|
||||
assert "Bob" in lines[2]
|
||||
|
||||
def test_multiple_sheets_separated(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [["a", "b"]],
|
||||
"Sheet2": [["c", "d"]],
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# TEXT_SECTION_SEPARATOR is "\n\n"
|
||||
assert "\n\n" in result
|
||||
parts = result.split("\n\n")
|
||||
assert any("a" in p for p in parts)
|
||||
assert any("c" in p for p in parts)
|
||||
|
||||
def test_empty_cells(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "b"],
|
||||
["", "c", ""],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
|
||||
def test_commas_in_cells_are_quoted(self) -> None:
|
||||
"""Cells containing commas should be quoted in CSV output."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["hello, world", "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert '"hello, world"' in result
|
||||
|
||||
def test_empty_workbook(self) -> None:
|
||||
xlsx = _make_xlsx({"Sheet1": []})
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert result.strip() == ""
|
||||
|
||||
def test_long_empty_row_run_capped(self) -> None:
|
||||
"""Runs of >2 empty rows should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["header"],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
["data"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
|
||||
assert len(lines) == 4
|
||||
assert "header" in lines[0]
|
||||
assert "data" in lines[-1]
|
||||
|
||||
def test_long_empty_col_run_capped(self) -> None:
|
||||
"""Runs of >2 empty columns should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "", "", "b"],
|
||||
["c", "", "", "", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
# Each row should have 4 fields (a + 2 empty + b), not 5
|
||||
# csv format: a,,,b (3 commas = 4 fields)
|
||||
first_line = lines[0].strip()
|
||||
# Count commas to verify column reduction
|
||||
assert first_line.count(",") == 3
|
||||
|
||||
def test_short_empty_runs_kept(self) -> None:
|
||||
"""Runs of <=2 empty rows/cols should be preserved."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "b"],
|
||||
["", ""],
|
||||
["", ""],
|
||||
["c", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# All 4 rows preserved (2 empty rows <= threshold)
|
||||
assert len(lines) == 4
|
||||
|
||||
def test_bad_zip_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="test.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_bad_zip_tilde_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_large_sparse_sheet(self) -> None:
|
||||
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
|
||||
rows: list[list[str]] = [["row1_data"]]
|
||||
rows.extend([[""] for _ in range(10)])
|
||||
rows.append(["row2_data"])
|
||||
xlsx = _make_xlsx({"Sheet1": rows})
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
|
||||
assert len(lines) == 4
|
||||
assert "row1_data" in lines[0]
|
||||
assert "row2_data" in lines[-1]
|
||||
|
||||
def test_quotes_in_cells(self) -> None:
|
||||
"""Cells containing quotes should be properly escaped."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
['say "hello"', "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# csv.writer escapes quotes by doubling them
|
||||
assert '""hello""' in result
|
||||
|
||||
def test_each_row_is_separate_line(self) -> None:
|
||||
"""Each row should produce its own line (regression for writerow vs writerows)."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["r1c1", "r1c2"],
|
||||
["r2c1", "r2c2"],
|
||||
["r3c1", "r3c2"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "r1c1" in lines[0] and "r1c2" in lines[0]
|
||||
assert "r2c1" in lines[1] and "r2c2" in lines[1]
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
@@ -2,6 +2,7 @@ import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -12,8 +13,13 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import _apply_document_ingestion_hook
|
||||
from onyx.indexing.indexing_pipeline import add_contextual_summaries
|
||||
from onyx.indexing.indexing_pipeline import filter_documents
|
||||
from onyx.indexing.indexing_pipeline import process_image_sections
|
||||
@@ -223,3 +229,148 @@ def test_contextual_rag(
|
||||
count += 1
|
||||
assert chunk.doc_summary == doc_summary
|
||||
assert chunk.chunk_context == chunk_context
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _apply_document_ingestion_hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_EXECUTE_HOOK = "onyx.indexing.indexing_pipeline.execute_hook"
|
||||
|
||||
|
||||
def _make_doc(
|
||||
doc_id: str = "doc1",
|
||||
sections: list[TextSection | ImageSection] | None = None,
|
||||
) -> Document:
|
||||
if sections is None:
|
||||
sections = [TextSection(text="Hello", link="http://example.com")]
|
||||
return Document(
|
||||
id=doc_id,
|
||||
title="Test Doc",
|
||||
semantic_identifier="test-doc",
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.FILE,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def test_document_ingestion_hook_skipped_passes_through() -> None:
|
||||
doc = _make_doc()
|
||||
with patch(_PATCH_EXECUTE_HOOK, return_value=HookSkipped()):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert result == [doc]
|
||||
|
||||
|
||||
def test_document_ingestion_hook_soft_failed_passes_through() -> None:
|
||||
doc = _make_doc()
|
||||
with patch(_PATCH_EXECUTE_HOOK, return_value=HookSoftFailed()):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert result == [doc]
|
||||
|
||||
|
||||
def test_document_ingestion_hook_none_sections_drops_document() -> None:
|
||||
doc = _make_doc()
|
||||
with patch(
|
||||
_PATCH_EXECUTE_HOOK,
|
||||
return_value=DocumentIngestionResponse(
|
||||
sections=None, rejection_reason="PII detected"
|
||||
),
|
||||
):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_document_ingestion_hook_all_invalid_sections_drops_document() -> None:
|
||||
"""A non-empty list where every section has neither text nor image_file_id drops the doc."""
|
||||
doc = _make_doc()
|
||||
with patch(
|
||||
_PATCH_EXECUTE_HOOK,
|
||||
return_value=DocumentIngestionResponse(sections=[DocumentIngestionSection()]),
|
||||
):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_document_ingestion_hook_empty_sections_drops_document() -> None:
|
||||
doc = _make_doc()
|
||||
with patch(
|
||||
_PATCH_EXECUTE_HOOK,
|
||||
return_value=DocumentIngestionResponse(sections=[]),
|
||||
):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_document_ingestion_hook_rewrites_text_sections() -> None:
|
||||
doc = _make_doc(sections=[TextSection(text="original", link="http://a.com")])
|
||||
with patch(
|
||||
_PATCH_EXECUTE_HOOK,
|
||||
return_value=DocumentIngestionResponse(
|
||||
sections=[DocumentIngestionSection(text="rewritten", link="http://b.com")]
|
||||
),
|
||||
):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert len(result) == 1
|
||||
assert len(result[0].sections) == 1
|
||||
section = result[0].sections[0]
|
||||
assert isinstance(section, TextSection)
|
||||
assert section.text == "rewritten"
|
||||
assert section.link == "http://b.com"
|
||||
|
||||
|
||||
def test_document_ingestion_hook_preserves_image_section_order() -> None:
|
||||
"""Hook receives all sections including images and controls final ordering."""
|
||||
image = ImageSection(image_file_id="img-1", link=None)
|
||||
doc = _make_doc(
|
||||
sections=cast(
|
||||
list[TextSection | ImageSection],
|
||||
[TextSection(text="original", link=None), image],
|
||||
)
|
||||
)
|
||||
# Hook moves the image before the text section
|
||||
with patch(
|
||||
_PATCH_EXECUTE_HOOK,
|
||||
return_value=DocumentIngestionResponse(
|
||||
sections=[
|
||||
DocumentIngestionSection(image_file_id="img-1", link=None),
|
||||
DocumentIngestionSection(text="rewritten", link=None),
|
||||
]
|
||||
),
|
||||
):
|
||||
result = _apply_document_ingestion_hook([doc], MagicMock())
|
||||
assert len(result) == 1
|
||||
sections = result[0].sections
|
||||
assert len(sections) == 2
|
||||
assert (
|
||||
isinstance(sections[0], ImageSection) and sections[0].image_file_id == "img-1"
|
||||
)
|
||||
assert isinstance(sections[1], TextSection) and sections[1].text == "rewritten"
|
||||
|
||||
|
||||
def test_document_ingestion_hook_mixed_batch() -> None:
|
||||
"""Drop one doc, rewrite another, pass through a third."""
|
||||
doc_drop = _make_doc(doc_id="drop")
|
||||
doc_rewrite = _make_doc(doc_id="rewrite")
|
||||
doc_skip = _make_doc(doc_id="skip")
|
||||
|
||||
def _side_effect(**kwargs: Any) -> Any:
|
||||
doc_id = kwargs["payload"]["document_id"]
|
||||
if doc_id == "drop":
|
||||
return DocumentIngestionResponse(sections=None)
|
||||
if doc_id == "rewrite":
|
||||
return DocumentIngestionResponse(
|
||||
sections=[DocumentIngestionSection(text="new text", link=None)]
|
||||
)
|
||||
return HookSkipped()
|
||||
|
||||
with patch(_PATCH_EXECUTE_HOOK, side_effect=_side_effect):
|
||||
result = _apply_document_ingestion_hook(
|
||||
[doc_drop, doc_rewrite, doc_skip], MagicMock()
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
ids = {d.id for d in result}
|
||||
assert ids == {"rewrite", "skip"}
|
||||
rewritten = next(d for d in result if d.id == "rewrite")
|
||||
assert isinstance(rewritten.sections[0], TextSection)
|
||||
assert rewritten.sections[0].text == "new text"
|
||||
|
||||
@@ -113,6 +113,7 @@ def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
group.name = kwargs.get("name", "Engineering")
|
||||
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
|
||||
group.is_up_to_date = kwargs.get("is_up_to_date", True)
|
||||
group.is_default = kwargs.get("is_default", False)
|
||||
return group
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
@@ -25,6 +26,7 @@ def _mock_user(
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.account_type = AccountType.STANDARD
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement == placement
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
@@ -193,9 +193,9 @@ hover, active, and disabled states.
|
||||
|
||||
### Disabled (`core/disabled/`)
|
||||
|
||||
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
|
||||
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
|
||||
interactive descendants.
|
||||
A pure CSS wrapper that applies disabled visuals (`opacity-50`, `cursor-not-allowed`,
|
||||
`pointer-events: none`) to a single child element via Radix `Slot`. Has no React context —
|
||||
Interactive primitives and buttons manage their own disabled state via a `disabled` prop.
|
||||
|
||||
### Hoverable (`core/animations/`)
|
||||
|
||||
@@ -231,6 +231,23 @@ import { Hoverable } from "@opal/core";
|
||||
|
||||
# Best Practices
|
||||
|
||||
## 0. Size Variant Defaults
|
||||
|
||||
**When using `SizeVariants` (or any subset like `PaddingVariants`, `RoundingVariants`) as a prop
|
||||
type, always default to `"md"`.**
|
||||
|
||||
**Reason:** `"md"` is the standard middle-of-the-road preset across the design system. Consistent
|
||||
defaults make components predictable — callers only need to specify a size when they want something
|
||||
other than the norm.
|
||||
|
||||
```typescript
|
||||
// ✅ Good — default to "md"
|
||||
function MyCard({ padding = "md", rounding = "md" }: MyCardProps) { ... }
|
||||
|
||||
// ❌ Bad — arbitrary or inconsistent defaults
|
||||
function MyCard({ padding = "sm", rounding = "lg" }: MyCardProps) { ... }
|
||||
```
|
||||
|
||||
## 1. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Disabled,
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
@@ -49,7 +45,7 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
@@ -94,7 +90,11 @@ function Button({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateless type={type} {...interactiveProps}>
|
||||
<Interactive.Stateless
|
||||
type={type}
|
||||
disabled={disabled}
|
||||
{...interactiveProps}
|
||||
>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
border={interactiveProps.prominence === "secondary"}
|
||||
@@ -118,28 +118,24 @@ function Button({
|
||||
</Interactive.Stateless>
|
||||
);
|
||||
|
||||
const result = tooltip ? (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
) : (
|
||||
button
|
||||
);
|
||||
|
||||
if (disabled != null) {
|
||||
return <Disabled disabled={disabled}>{result}</Disabled>;
|
||||
if (tooltip) {
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
return button;
|
||||
}
|
||||
|
||||
export { Button, type ButtonProps };
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
type InteractiveStatefulInteraction,
|
||||
} from "@opal/core";
|
||||
@@ -74,6 +73,9 @@ type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
|
||||
|
||||
/** Override the default rounding derived from `size`. */
|
||||
roundingVariant?: InteractiveContainerRoundingVariant;
|
||||
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -92,10 +94,9 @@ function OpenButton({
|
||||
roundingVariant: roundingVariantOverride,
|
||||
interaction,
|
||||
variant = "select-heavy",
|
||||
disabled,
|
||||
...statefulProps
|
||||
}: OpenButtonProps) {
|
||||
const { isDisabled } = useDisabled();
|
||||
|
||||
// Derive open state: explicit prop → Radix data-state (injected via Slot chain)
|
||||
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
|
||||
| string
|
||||
@@ -119,6 +120,7 @@ function OpenButton({
|
||||
<Interactive.Stateful
|
||||
variant={variant}
|
||||
interaction={resolvedInteraction}
|
||||
disabled={disabled}
|
||||
{...statefulProps}
|
||||
>
|
||||
<Interactive.Container
|
||||
@@ -168,7 +170,7 @@ function OpenButton({
|
||||
);
|
||||
|
||||
const resolvedTooltip =
|
||||
tooltip ?? (foldable && isDisabled && children ? children : undefined);
|
||||
tooltip ?? (foldable && disabled && children ? children : undefined);
|
||||
|
||||
if (!resolvedTooltip) return button;
|
||||
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/components/buttons/select-button/styles.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
@@ -64,6 +60,9 @@ type SelectButtonProps = InteractiveStatefulProps &
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -80,9 +79,9 @@ function SelectButton({
|
||||
width,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
disabled,
|
||||
...statefulProps
|
||||
}: SelectButtonProps) {
|
||||
const { isDisabled } = useDisabled();
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
@@ -96,7 +95,7 @@ function SelectButton({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateful {...statefulProps}>
|
||||
<Interactive.Stateful disabled={disabled} {...statefulProps}>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
heightVariant={size}
|
||||
@@ -128,7 +127,7 @@ function SelectButton({
|
||||
);
|
||||
|
||||
const resolvedTooltip =
|
||||
tooltip ?? (foldable && isDisabled && children ? children : undefined);
|
||||
tooltip ?? (foldable && disabled && children ? children : undefined);
|
||||
|
||||
if (!resolvedTooltip) return button;
|
||||
|
||||
|
||||
59
web/lib/opal/src/components/buttons/sidebar-tab/README.md
Normal file
59
web/lib/opal/src/components/buttons/sidebar-tab/README.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# SidebarTab
|
||||
|
||||
**Import:** `import { SidebarTab, type SidebarTabProps } from "@opal/components";`
|
||||
|
||||
A sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`. Designed for admin and app sidebars.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
div.relative
|
||||
└─ Interactive.Stateful <- variant (sidebar-heavy | sidebar-light), state, disabled
|
||||
└─ Interactive.Container <- rounding, height, width
|
||||
├─ Link? (absolute overlay for client-side navigation)
|
||||
├─ rightChildren? (absolute, above Link for inline actions)
|
||||
└─ ContentAction (icon + title + truncation spacer)
|
||||
```
|
||||
|
||||
- **`sidebar-heavy`** (default) — muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
|
||||
- **`sidebar-light`** (via `lowlight`) — uniformly muted across all states (text-02/text-02)
|
||||
- **Disabled** — both variants use text-02 foreground, transparent background, no hover/active states
|
||||
- **Navigation** uses an absolutely positioned `<Link>` overlay rather than `href` on the Interactive element, so `rightChildren` can sit above it with `pointer-events-auto`.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `icon` | `IconFunctionComponent` | — | Left icon |
|
||||
| `children` | `ReactNode` | — | Label text or custom content |
|
||||
| `selected` | `boolean` | `false` | Active/selected state |
|
||||
| `lowlight` | `boolean` | `false` | Uses muted `sidebar-light` variant |
|
||||
| `disabled` | `boolean` | `false` | Disables the tab |
|
||||
| `folded` | `boolean` | `false` | Collapses label, shows tooltip on hover |
|
||||
| `nested` | `boolean` | `false` | Renders spacer instead of icon for indented items |
|
||||
| `href` | `string` | — | Client-side navigation URL |
|
||||
| `onClick` | `MouseEventHandler` | — | Click handler |
|
||||
| `type` | `ButtonType` | — | HTML button type |
|
||||
| `rightChildren` | `ReactNode` | — | Actions rendered on the right side |
|
||||
|
||||
## Usage
|
||||
|
||||
```tsx
|
||||
import { SidebarTab } from "@opal/components";
|
||||
import { SvgSettings, SvgLock } from "@opal/icons";
|
||||
|
||||
// Active tab
|
||||
<SidebarTab icon={SvgSettings} href="/admin/settings" selected>
|
||||
Settings
|
||||
</SidebarTab>
|
||||
|
||||
// Disabled enterprise-only tab
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
Groups
|
||||
</SidebarTab>
|
||||
|
||||
// Folded sidebar (icon only, tooltip on hover)
|
||||
<SidebarTab icon={SvgSettings} href="/admin/settings" folded>
|
||||
Settings
|
||||
</SidebarTab>
|
||||
```
|
||||
@@ -0,0 +1,90 @@
|
||||
import React from "react";
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { SidebarTab } from "@opal/components/buttons/sidebar-tab/components";
|
||||
import { SvgSettings, SvgUsers, SvgLock, SvgArrowUpCircle } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgTrash } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
|
||||
const meta: Meta<typeof SidebarTab> = {
|
||||
title: "opal/components/SidebarTab",
|
||||
component: SidebarTab,
|
||||
tags: ["autodocs"],
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<TooltipPrimitive.Provider>
|
||||
<div style={{ width: 260, background: "var(--background-neutral-01)" }}>
|
||||
<Story />
|
||||
</div>
|
||||
</TooltipPrimitive.Provider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof SidebarTab>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
},
|
||||
};
|
||||
|
||||
export const Selected: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
selected: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const Lowlight: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
lowlight: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
icon: SvgLock,
|
||||
children: "Enterprise Only",
|
||||
disabled: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithRightChildren: Story = {
|
||||
args: {
|
||||
icon: SvgUsers,
|
||||
children: "Users",
|
||||
rightChildren: (
|
||||
<Button
|
||||
icon={SvgTrash}
|
||||
size="xs"
|
||||
prominence="tertiary"
|
||||
variant="danger"
|
||||
/>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const SidebarExample: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col">
|
||||
<SidebarTab icon={SvgSettings} selected>
|
||||
LLM Models
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgSettings}>Web Search</SidebarTab>
|
||||
<SidebarTab icon={SvgUsers}>Users</SidebarTab>
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
Groups
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
SCIM
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgArrowUpCircle}>Upgrade Plan</SidebarTab>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
159
web/lib/opal/src/components/buttons/sidebar-tab/components.tsx
Normal file
159
web/lib/opal/src/components/buttons/sidebar-tab/components.tsx
Normal file
@@ -0,0 +1,159 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import type { ButtonType, IconFunctionComponent } from "@opal/types";
|
||||
import type { Route } from "next";
|
||||
import { Interactive } from "@opal/core";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { Text } from "@opal/components";
|
||||
import Link from "next/link";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import "@opal/components/tooltip.css";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface SidebarTabProps {
|
||||
/** Collapses the label, showing only the icon. */
|
||||
folded?: boolean;
|
||||
|
||||
/** Marks this tab as the currently active/selected item. */
|
||||
selected?: boolean;
|
||||
|
||||
/** Uses the muted `sidebar-light` variant instead of `sidebar-heavy`. */
|
||||
lowlight?: boolean;
|
||||
|
||||
/** Renders an empty spacer in place of the icon for nested items. */
|
||||
nested?: boolean;
|
||||
|
||||
/** Disables the tab — applies muted colors and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
|
||||
onClick?: React.MouseEventHandler<HTMLElement>;
|
||||
href?: string;
|
||||
type?: ButtonType;
|
||||
icon?: IconFunctionComponent;
|
||||
children?: React.ReactNode;
|
||||
|
||||
/** Content rendered on the right side (e.g. action buttons). */
|
||||
rightChildren?: React.ReactNode;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SidebarTab
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`.
|
||||
*
|
||||
* Uses `sidebar-heavy` (default) or `sidebar-light` (when `lowlight`) variants
|
||||
* for color styling. Supports an overlay `Link` for client-side navigation,
|
||||
* `rightChildren` for inline actions, and folded mode with an auto-tooltip.
|
||||
*/
|
||||
function SidebarTab({
|
||||
folded,
|
||||
selected,
|
||||
lowlight,
|
||||
nested,
|
||||
disabled,
|
||||
|
||||
onClick,
|
||||
href,
|
||||
type,
|
||||
icon,
|
||||
rightChildren,
|
||||
children,
|
||||
}: SidebarTabProps) {
|
||||
const Icon =
|
||||
icon ??
|
||||
(nested
|
||||
? ((() => (
|
||||
<div className="w-6" aria-hidden="true" />
|
||||
)) as IconFunctionComponent)
|
||||
: null);
|
||||
|
||||
// The `rightChildren` node is absolutely positioned to sit on top of the
|
||||
// overlay Link. A zero-width spacer reserves truncation space for the title.
|
||||
const truncationSpacer = rightChildren && (
|
||||
<div className="w-0 group-hover/SidebarTab:w-6" />
|
||||
);
|
||||
|
||||
const content = (
|
||||
<div className="relative">
|
||||
<Interactive.Stateful
|
||||
variant={lowlight ? "sidebar-light" : "sidebar-heavy"}
|
||||
state={selected ? "selected" : "empty"}
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
type="button"
|
||||
group="group/SidebarTab"
|
||||
>
|
||||
<Interactive.Container
|
||||
roundingVariant="sm"
|
||||
heightVariant="lg"
|
||||
widthVariant="full"
|
||||
type={type}
|
||||
>
|
||||
{href && !disabled && (
|
||||
<Link
|
||||
href={href as Route}
|
||||
scroll={false}
|
||||
className="absolute z-[99] inset-0 rounded-08"
|
||||
tabIndex={-1}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!folded && rightChildren && (
|
||||
<div className="absolute z-[100] right-1.5 top-0 bottom-0 flex flex-col justify-center items-center pointer-events-auto">
|
||||
{rightChildren}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{typeof children === "string" ? (
|
||||
<ContentAction
|
||||
icon={Icon ?? undefined}
|
||||
title={folded ? "" : children}
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
widthVariant="full"
|
||||
paddingVariant="fit"
|
||||
rightChildren={truncationSpacer}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex flex-row items-center gap-2 flex-1">
|
||||
{Icon && (
|
||||
<div className="flex items-center justify-center p-0.5">
|
||||
<Icon className="h-[1rem] w-[1rem] text-text-03" />
|
||||
</div>
|
||||
)}
|
||||
{children}
|
||||
{truncationSpacer}
|
||||
</div>
|
||||
)}
|
||||
</Interactive.Container>
|
||||
</Interactive.Stateful>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (typeof children !== "string") return content;
|
||||
if (folded) {
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{content}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side="right"
|
||||
sideOffset={4}
|
||||
>
|
||||
<Text>{children}</Text>
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
return content;
|
||||
}
|
||||
|
||||
export { SidebarTab, type SidebarTabProps };
|
||||
@@ -29,7 +29,7 @@ export const BackgroundVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{BACKGROUND_VARIANTS.map((bg) => (
|
||||
<Card key={bg} backgroundVariant={bg} borderVariant="solid">
|
||||
<Card key={bg} background={bg} border="solid">
|
||||
<p>backgroundVariant: {bg}</p>
|
||||
</Card>
|
||||
))}
|
||||
@@ -41,7 +41,7 @@ export const BorderVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{BORDER_VARIANTS.map((border) => (
|
||||
<Card key={border} borderVariant={border}>
|
||||
<Card key={border} border={border}>
|
||||
<p>borderVariant: {border}</p>
|
||||
</Card>
|
||||
))}
|
||||
@@ -53,7 +53,7 @@ export const PaddingVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{PADDING_VARIANTS.map((padding) => (
|
||||
<Card key={padding} paddingVariant={padding} borderVariant="solid">
|
||||
<Card key={padding} padding={padding} border="solid">
|
||||
<p>paddingVariant: {padding}</p>
|
||||
</Card>
|
||||
))}
|
||||
@@ -65,7 +65,7 @@ export const RoundingVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{ROUNDING_VARIANTS.map((rounding) => (
|
||||
<Card key={rounding} roundingVariant={rounding} borderVariant="solid">
|
||||
<Card key={rounding} rounding={rounding} border="solid">
|
||||
<p>roundingVariant: {rounding}</p>
|
||||
</Card>
|
||||
))}
|
||||
@@ -84,9 +84,9 @@ export const AllCombinations: Story = {
|
||||
BORDER_VARIANTS.map((border) => (
|
||||
<Card
|
||||
key={`${padding}-${bg}-${border}`}
|
||||
paddingVariant={padding}
|
||||
backgroundVariant={bg}
|
||||
borderVariant={border}
|
||||
padding={padding}
|
||||
background={bg}
|
||||
border={border}
|
||||
>
|
||||
<p className="text-xs">
|
||||
bg: {bg}, border: {border}
|
||||
|
||||
@@ -8,30 +8,30 @@ A plain container component with configurable background, border, padding, and r
|
||||
|
||||
Padding and rounding are controlled independently:
|
||||
|
||||
| `paddingVariant` | Class |
|
||||
|------------------|---------|
|
||||
| `"lg"` | `p-6` |
|
||||
| `"md"` | `p-4` |
|
||||
| `"sm"` | `p-2` |
|
||||
| `"xs"` | `p-1` |
|
||||
| `"2xs"` | `p-0.5` |
|
||||
| `"fit"` | `p-0` |
|
||||
| `padding` | Class |
|
||||
|-----------|---------|
|
||||
| `"lg"` | `p-6` |
|
||||
| `"md"` | `p-4` |
|
||||
| `"sm"` | `p-2` |
|
||||
| `"xs"` | `p-1` |
|
||||
| `"2xs"` | `p-0.5` |
|
||||
| `"fit"` | `p-0` |
|
||||
|
||||
| `roundingVariant` | Class |
|
||||
|-------------------|--------------|
|
||||
| `"xs"` | `rounded-04` |
|
||||
| `"sm"` | `rounded-08` |
|
||||
| `"md"` | `rounded-12` |
|
||||
| `"lg"` | `rounded-16` |
|
||||
| `rounding` | Class |
|
||||
|------------|--------------|
|
||||
| `"xs"` | `rounded-04` |
|
||||
| `"sm"` | `rounded-08` |
|
||||
| `"md"` | `rounded-12` |
|
||||
| `"lg"` | `rounded-16` |
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset |
|
||||
| `roundingVariant` | `RoundingVariants` | `"md"` | Border-radius preset |
|
||||
| `backgroundVariant` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
|
||||
| `borderVariant` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
|
||||
| `padding` | `PaddingVariants` | `"sm"` | Padding preset |
|
||||
| `rounding` | `RoundingVariants` | `"md"` | Border-radius preset |
|
||||
| `background` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
|
||||
| `border` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
|
||||
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
|
||||
| `children` | `React.ReactNode` | — | Card content |
|
||||
|
||||
@@ -47,17 +47,17 @@ import { Card } from "@opal/components";
|
||||
</Card>
|
||||
|
||||
// Large padding + rounding with solid border
|
||||
<Card paddingVariant="lg" roundingVariant="lg" borderVariant="solid">
|
||||
<Card padding="lg" rounding="lg" border="solid">
|
||||
<p>Spacious card</p>
|
||||
</Card>
|
||||
|
||||
// Compact card with solid border
|
||||
<Card paddingVariant="xs" roundingVariant="sm" borderVariant="solid">
|
||||
<Card padding="xs" rounding="sm" border="solid">
|
||||
<p>Compact card</p>
|
||||
</Card>
|
||||
|
||||
// Empty state card
|
||||
<Card backgroundVariant="none" borderVariant="dashed">
|
||||
<Card background="none" border="dashed">
|
||||
<p>No items yet</p>
|
||||
</Card>
|
||||
```
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import "@opal/components/cards/card/styles.css";
|
||||
import type { PaddingVariants, RoundingVariants } from "@opal/types";
|
||||
import { cardPaddingVariants, cardRoundingVariants } from "@opal/shared";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -22,9 +23,9 @@ type CardProps = {
|
||||
* | `"2xs"` | `p-0.5` |
|
||||
* | `"fit"` | `p-0` |
|
||||
*
|
||||
* @default "sm"
|
||||
* @default "md"
|
||||
*/
|
||||
paddingVariant?: PaddingVariants;
|
||||
padding?: PaddingVariants;
|
||||
|
||||
/**
|
||||
* Border-radius preset.
|
||||
@@ -38,7 +39,7 @@ type CardProps = {
|
||||
*
|
||||
* @default "md"
|
||||
*/
|
||||
roundingVariant?: RoundingVariants;
|
||||
rounding?: RoundingVariants;
|
||||
|
||||
/**
|
||||
* Background fill intensity.
|
||||
@@ -48,7 +49,7 @@ type CardProps = {
|
||||
*
|
||||
* @default "light"
|
||||
*/
|
||||
backgroundVariant?: BackgroundVariant;
|
||||
background?: BackgroundVariant;
|
||||
|
||||
/**
|
||||
* Border style.
|
||||
@@ -58,7 +59,7 @@ type CardProps = {
|
||||
*
|
||||
* @default "none"
|
||||
*/
|
||||
borderVariant?: BorderVariant;
|
||||
border?: BorderVariant;
|
||||
|
||||
/** Ref forwarded to the root `<div>`. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
@@ -66,47 +67,27 @@ type CardProps = {
|
||||
children?: React.ReactNode;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mappings
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const paddingForVariant: Record<PaddingVariants, string> = {
|
||||
lg: "p-6",
|
||||
md: "p-4",
|
||||
sm: "p-2",
|
||||
xs: "p-1",
|
||||
"2xs": "p-0.5",
|
||||
fit: "p-0",
|
||||
};
|
||||
|
||||
const roundingForVariant: Record<RoundingVariants, string> = {
|
||||
lg: "rounded-16",
|
||||
md: "rounded-12",
|
||||
sm: "rounded-08",
|
||||
xs: "rounded-04",
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Card
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function Card({
|
||||
paddingVariant = "sm",
|
||||
roundingVariant = "md",
|
||||
backgroundVariant = "light",
|
||||
borderVariant = "none",
|
||||
padding: paddingProp = "md",
|
||||
rounding: roundingProp = "md",
|
||||
background = "light",
|
||||
border = "none",
|
||||
ref,
|
||||
children,
|
||||
}: CardProps) {
|
||||
const padding = paddingForVariant[paddingVariant];
|
||||
const rounding = roundingForVariant[roundingVariant];
|
||||
const padding = cardPaddingVariants[paddingProp];
|
||||
const rounding = cardRoundingVariants[roundingProp];
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("opal-card", padding, rounding)}
|
||||
data-background={backgroundVariant}
|
||||
data-border={borderVariant}
|
||||
data-background={background}
|
||||
data-border={border}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { EmptyMessageCard } from "@opal/components";
|
||||
import { SvgSparkle, SvgUsers } from "@opal/icons";
|
||||
|
||||
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
|
||||
const PADDING_VARIANTS = ["fit", "2xs", "xs", "sm", "md", "lg"] as const;
|
||||
|
||||
const meta: Meta<typeof EmptyMessageCard> = {
|
||||
title: "opal/components/EmptyMessageCard",
|
||||
@@ -26,14 +26,14 @@ export const WithCustomIcon: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const SizeVariants: Story = {
|
||||
export const PaddingVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{SIZE_VARIANTS.map((size) => (
|
||||
{PADDING_VARIANTS.map((padding) => (
|
||||
<EmptyMessageCard
|
||||
key={size}
|
||||
sizeVariant={size}
|
||||
title={`sizeVariant: ${size}`}
|
||||
key={padding}
|
||||
padding={padding}
|
||||
title={`padding: ${padding}`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -6,12 +6,12 @@ A pre-configured Card for empty states. Renders a transparent card with a dashed
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
| ----------------- | --------------------------- | ---------- | ------------------------------------------------ |
|
||||
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
|
||||
| `title` | `string` | — | Primary message text (required) |
|
||||
| `paddingVariant` | `PaddingVariants` | `"sm"` | Padding preset for the card |
|
||||
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
|
||||
| Prop | Type | Default | Description |
|
||||
| --------- | --------------------------- | ---------- | -------------------------------- |
|
||||
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
|
||||
| `title` | `string` | — | Primary message text (required) |
|
||||
| `padding` | `PaddingVariants` | `"sm"` | Padding preset for the card |
|
||||
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -26,5 +26,5 @@ import { SvgSparkle, SvgFileText } from "@opal/icons";
|
||||
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
|
||||
|
||||
// With custom padding
|
||||
<EmptyMessageCard paddingVariant="xs" icon={SvgFileText} title="No documents available." />
|
||||
<EmptyMessageCard padding="xs" icon={SvgFileText} title="No documents available." />
|
||||
```
|
||||
|
||||
@@ -14,8 +14,8 @@ type EmptyMessageCardProps = {
|
||||
/** Primary message text. */
|
||||
title: string;
|
||||
|
||||
/** Padding preset for the card. */
|
||||
paddingVariant?: PaddingVariants;
|
||||
/** Padding preset for the card. @default "md" */
|
||||
padding?: PaddingVariants;
|
||||
|
||||
/** Ref forwarded to the root Card div. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
@@ -28,15 +28,16 @@ type EmptyMessageCardProps = {
|
||||
function EmptyMessageCard({
|
||||
icon = SvgEmpty,
|
||||
title,
|
||||
paddingVariant = "sm",
|
||||
padding = "md",
|
||||
ref,
|
||||
}: EmptyMessageCardProps) {
|
||||
return (
|
||||
<Card
|
||||
ref={ref}
|
||||
backgroundVariant="none"
|
||||
borderVariant="dashed"
|
||||
paddingVariant={paddingVariant}
|
||||
background="none"
|
||||
border="dashed"
|
||||
padding={padding}
|
||||
rounding="md"
|
||||
>
|
||||
<Content
|
||||
icon={icon}
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
**Import:** `import { SelectCard, type SelectCardProps } from "@opal/components";`
|
||||
|
||||
A stateful interactive card — the card counterpart to [`SelectButton`](../../buttons/select-button/README.md). Built on `Interactive.Stateful` (Slot) with a structural `<div>` that owns padding, rounding, border, and overflow.
|
||||
A stateful interactive card — the card counterpart to [`SelectButton`](../../buttons/select-button/README.md). Built on `Interactive.Stateful` (Slot) with a structural `<div>` that owns padding, rounding, border, and overflow. Always uses the `select-card` Interactive.Stateful variant internally.
|
||||
|
||||
## Relationship to Card
|
||||
|
||||
`Card` is a plain, non-interactive container. `SelectCard` adds stateful interactivity (hover, active, disabled, state-driven colors) by wrapping its root div with `Interactive.Stateful`. The relationship mirrors `Button` (stateless) vs `SelectButton` (stateful).
|
||||
`Card` is a plain, non-interactive container. `SelectCard` adds stateful interactivity (hover, active, disabled, state-driven colors) by wrapping its root div with `Interactive.Stateful`. Both share the same independent `padding` / `rounding` API.
|
||||
|
||||
## Relationship to SelectButton
|
||||
|
||||
@@ -18,15 +18,15 @@ Interactive.Stateful → structural element → content
|
||||
|
||||
The key differences:
|
||||
|
||||
- SelectCard renders a `<div>` (not `Interactive.Container`) — cards have their own rounding scale (one notch larger than buttons) and don't need Container's height/min-width.
|
||||
- SelectCard renders a `<div>` (not `Interactive.Container`) — cards have their own rounding scale and don't need Container's height/min-width.
|
||||
- SelectCard has no `foldable` prop — use `Interactive.Foldable` directly inside children.
|
||||
- SelectCard's children are fully composable — use `CardHeaderLayout`, `ContentAction`, `Content`, buttons, etc. inside.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Interactive.Stateful <- variant, state, interaction, disabled, onClick
|
||||
└─ div.opal-select-card <- padding, rounding, border, overflow
|
||||
Interactive.Stateful (variant="select-card") <- state, interaction, disabled, onClick
|
||||
└─ div.opal-select-card <- padding, rounding, border, overflow
|
||||
└─ children (composable)
|
||||
```
|
||||
|
||||
@@ -34,28 +34,36 @@ The `Interactive.Stateful` Slot merges onto the div, producing a single DOM elem
|
||||
|
||||
## Props
|
||||
|
||||
Inherits **all** props from `InteractiveStatefulProps` (variant, state, interaction, onClick, href, etc.) plus:
|
||||
Inherits **all** props from `InteractiveStatefulProps` (except `variant`, which is hardcoded to `select-card`) plus:
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `sizeVariant` | `ContainerSizeVariants` | `"lg"` | Controls padding and border-radius |
|
||||
| `padding` | `PaddingVariants` | `"sm"` | Padding preset |
|
||||
| `rounding` | `RoundingVariants` | `"lg"` | Border-radius preset |
|
||||
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
|
||||
| `children` | `React.ReactNode` | — | Card content |
|
||||
|
||||
### Padding scale
|
||||
|
||||
| `padding` | Class |
|
||||
|-----------|---------|
|
||||
| `"lg"` | `p-6` |
|
||||
| `"md"` | `p-4` |
|
||||
| `"sm"` | `p-2` |
|
||||
| `"xs"` | `p-1` |
|
||||
| `"2xs"` | `p-0.5` |
|
||||
| `"fit"` | `p-0` |
|
||||
|
||||
### Rounding scale
|
||||
|
||||
Cards use a bumped-up rounding scale compared to buttons:
|
||||
| `rounding` | Class |
|
||||
|------------|--------------|
|
||||
| `"xs"` | `rounded-04` |
|
||||
| `"sm"` | `rounded-08` |
|
||||
| `"md"` | `rounded-12` |
|
||||
| `"lg"` | `rounded-16` |
|
||||
|
||||
| Size | Rounding | Effective radius |
|
||||
|---|---|---|
|
||||
| `lg` | `rounded-16` | 1rem (16px) |
|
||||
| `md`–`sm` | `rounded-12` | 0.75rem (12px) |
|
||||
| `xs`–`2xs` | `rounded-08` | 0.5rem (8px) |
|
||||
| `fit` | `rounded-16` | 1rem (16px) |
|
||||
|
||||
### Recommended variant: `select-card`
|
||||
|
||||
The `select-card` Interactive.Stateful variant is specifically designed for cards. Unlike `select-heavy` (which only changes foreground color between empty and filled), `select-card` gives the filled state a visible background — important on larger surfaces where background carries more of the visual distinction.
|
||||
### State colors (`select-card` variant)
|
||||
|
||||
| State | Rest background | Rest foreground |
|
||||
|---|---|---|
|
||||
@@ -82,7 +90,7 @@ All background and foreground colors come from the Interactive.Stateful CSS, not
|
||||
import { SelectCard } from "@opal/components";
|
||||
import { CardHeaderLayout } from "@opal/layouts";
|
||||
|
||||
<SelectCard variant="select-card" state="selected" onClick={handleClick}>
|
||||
<SelectCard state="selected" onClick={handleClick}>
|
||||
<CardHeaderLayout
|
||||
icon={SvgGlobe}
|
||||
title="Google"
|
||||
@@ -100,7 +108,7 @@ import { CardHeaderLayout } from "@opal/layouts";
|
||||
### Disconnected state (clickable)
|
||||
|
||||
```tsx
|
||||
<SelectCard variant="select-card" state="empty" onClick={handleConnect}>
|
||||
<SelectCard state="empty" onClick={handleConnect}>
|
||||
<CardHeaderLayout
|
||||
icon={SvgCloud}
|
||||
title="OpenAI"
|
||||
@@ -115,7 +123,7 @@ import { CardHeaderLayout } from "@opal/layouts";
|
||||
### With foldable hover-reveal
|
||||
|
||||
```tsx
|
||||
<SelectCard variant="select-card" state="filled">
|
||||
<SelectCard state="filled">
|
||||
<CardHeaderLayout
|
||||
icon={SvgCloud}
|
||||
title="OpenAI"
|
||||
|
||||
@@ -21,7 +21,8 @@ const withTooltipProvider: Decorator = (Story) => (
|
||||
);
|
||||
|
||||
const STATES = ["empty", "filled", "selected"] as const;
|
||||
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
|
||||
const PADDING_VARIANTS = ["fit", "2xs", "xs", "sm", "md", "lg"] as const;
|
||||
const ROUNDING_VARIANTS = ["xs", "sm", "md", "lg"] as const;
|
||||
|
||||
const meta = {
|
||||
title: "opal/components/SelectCard",
|
||||
@@ -44,7 +45,7 @@ type Story = StoryObj<typeof meta>;
|
||||
export const Default: Story = {
|
||||
render: () => (
|
||||
<div className="w-96">
|
||||
<SelectCard variant="select-card" state="empty">
|
||||
<SelectCard state="empty">
|
||||
<div className="p-2">
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
@@ -63,7 +64,7 @@ export const AllStates: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{STATES.map((state) => (
|
||||
<SelectCard key={state} variant="select-card" state={state}>
|
||||
<SelectCard key={state} state={state}>
|
||||
<div className="p-2">
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
@@ -82,11 +83,7 @@ export const AllStates: Story = {
|
||||
export const Clickable: Story = {
|
||||
render: () => (
|
||||
<div className="w-96">
|
||||
<SelectCard
|
||||
variant="select-card"
|
||||
state="empty"
|
||||
onClick={() => alert("Card clicked")}
|
||||
>
|
||||
<SelectCard state="empty" onClick={() => alert("Card clicked")}>
|
||||
<div className="p-2">
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
@@ -105,7 +102,7 @@ export const WithActions: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-[28rem]">
|
||||
{/* Disconnected */}
|
||||
<SelectCard variant="select-card" state="empty" onClick={() => {}}>
|
||||
<SelectCard state="empty" onClick={() => {}}>
|
||||
<div className="flex flex-row items-stretch w-full">
|
||||
<div className="flex-1 p-2">
|
||||
<Content
|
||||
@@ -125,7 +122,7 @@ export const WithActions: Story = {
|
||||
</SelectCard>
|
||||
|
||||
{/* Connected with foldable */}
|
||||
<SelectCard variant="select-card" state="filled">
|
||||
<SelectCard state="filled">
|
||||
<div className="flex flex-row items-stretch w-full">
|
||||
<div className="flex-1 p-2">
|
||||
<Content
|
||||
@@ -163,7 +160,7 @@ export const WithActions: Story = {
|
||||
</SelectCard>
|
||||
|
||||
{/* Selected */}
|
||||
<SelectCard variant="select-card" state="selected">
|
||||
<SelectCard state="selected">
|
||||
<div className="flex flex-row items-stretch w-full">
|
||||
<div className="flex-1 p-2">
|
||||
<Content
|
||||
@@ -203,22 +200,17 @@ export const WithActions: Story = {
|
||||
),
|
||||
};
|
||||
|
||||
export const SizeVariants: Story = {
|
||||
export const PaddingVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{SIZE_VARIANTS.map((size) => (
|
||||
<SelectCard
|
||||
key={size}
|
||||
variant="select-card"
|
||||
state="filled"
|
||||
sizeVariant={size}
|
||||
>
|
||||
{PADDING_VARIANTS.map((padding) => (
|
||||
<SelectCard key={padding} state="filled" padding={padding}>
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgGlobe}
|
||||
title={`sizeVariant: ${size}`}
|
||||
description="Shows padding and rounding differences."
|
||||
title={`paddingVariant: ${padding}`}
|
||||
description="Shows padding differences."
|
||||
/>
|
||||
</SelectCard>
|
||||
))}
|
||||
@@ -226,20 +218,18 @@ export const SizeVariants: Story = {
|
||||
),
|
||||
};
|
||||
|
||||
export const SelectHeavyVariant: Story = {
|
||||
export const RoundingVariants: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col gap-4 w-96">
|
||||
{STATES.map((state) => (
|
||||
<SelectCard key={state} variant="select-heavy" state={state}>
|
||||
<div className="p-2">
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgGlobe}
|
||||
title={`select-heavy / ${state}`}
|
||||
description="For comparison with select-card variant."
|
||||
/>
|
||||
</div>
|
||||
{ROUNDING_VARIANTS.map((rounding) => (
|
||||
<SelectCard key={rounding} state="filled" rounding={rounding}>
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgGlobe}
|
||||
title={`roundingVariant: ${rounding}`}
|
||||
description="Shows rounding differences."
|
||||
/>
|
||||
</SelectCard>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import "@opal/components/cards/select-card/styles.css";
|
||||
import type { ContainerSizeVariants } from "@opal/types";
|
||||
import { containerSizeVariants } from "@opal/shared";
|
||||
import type { PaddingVariants, RoundingVariants } from "@opal/types";
|
||||
import { cardPaddingVariants, cardRoundingVariants } from "@opal/shared";
|
||||
import { cn } from "@opal/utils";
|
||||
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
|
||||
|
||||
@@ -8,23 +8,36 @@ import { Interactive, type InteractiveStatefulProps } from "@opal/core";
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type SelectCardProps = InteractiveStatefulProps & {
|
||||
type SelectCardProps = Omit<InteractiveStatefulProps, "variant"> & {
|
||||
/**
|
||||
* Size preset — controls padding and border-radius.
|
||||
* Padding preset.
|
||||
*
|
||||
* Padding comes from the shared size scale. Rounding follows the same
|
||||
* mapping as `Card` / `Button` / `Interactive.Container`:
|
||||
* | Value | Class |
|
||||
* |---------|---------|
|
||||
* | `"lg"` | `p-6` |
|
||||
* | `"md"` | `p-4` |
|
||||
* | `"sm"` | `p-2` |
|
||||
* | `"xs"` | `p-1` |
|
||||
* | `"2xs"` | `p-0.5` |
|
||||
* | `"fit"` | `p-0` |
|
||||
*
|
||||
* | Size | Rounding |
|
||||
* |------------|--------------|
|
||||
* | `lg` | `rounded-16` |
|
||||
* | `md`–`sm` | `rounded-12` |
|
||||
* | `xs`–`2xs` | `rounded-08` |
|
||||
* | `fit` | `rounded-16` |
|
||||
*
|
||||
* @default "lg"
|
||||
* @default "md"
|
||||
*/
|
||||
sizeVariant?: ContainerSizeVariants;
|
||||
padding?: PaddingVariants;
|
||||
|
||||
/**
|
||||
* Border-radius preset.
|
||||
*
|
||||
* | Value | Class |
|
||||
* |--------|--------------|
|
||||
* | `"xs"` | `rounded-04` |
|
||||
* | `"sm"` | `rounded-08` |
|
||||
* | `"md"` | `rounded-12` |
|
||||
* | `"lg"` | `rounded-16` |
|
||||
*
|
||||
* @default "md"
|
||||
*/
|
||||
rounding?: RoundingVariants;
|
||||
|
||||
/** Ref forwarded to the root `<div>`. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
@@ -32,19 +45,6 @@ type SelectCardProps = InteractiveStatefulProps & {
|
||||
children?: React.ReactNode;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Rounding
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const roundingForSize: Record<ContainerSizeVariants, string> = {
|
||||
lg: "rounded-16",
|
||||
md: "rounded-12",
|
||||
sm: "rounded-12",
|
||||
xs: "rounded-08",
|
||||
"2xs": "rounded-08",
|
||||
fit: "rounded-16",
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SelectCard
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -61,7 +61,7 @@ const roundingForSize: Record<ContainerSizeVariants, string> = {
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <SelectCard variant="select-card" state="selected" onClick={handleClick}>
|
||||
* <SelectCard state="selected" onClick={handleClick}>
|
||||
* <ContentAction
|
||||
* icon={SvgGlobe}
|
||||
* title="Google"
|
||||
@@ -72,16 +72,17 @@ const roundingForSize: Record<ContainerSizeVariants, string> = {
|
||||
* ```
|
||||
*/
|
||||
function SelectCard({
|
||||
sizeVariant = "lg",
|
||||
padding: paddingProp = "md",
|
||||
rounding: roundingProp = "md",
|
||||
ref,
|
||||
children,
|
||||
...statefulProps
|
||||
}: SelectCardProps) {
|
||||
const { padding } = containerSizeVariants[sizeVariant];
|
||||
const rounding = roundingForSize[sizeVariant];
|
||||
const padding = cardPaddingVariants[paddingProp];
|
||||
const rounding = cardRoundingVariants[roundingProp];
|
||||
|
||||
return (
|
||||
<Interactive.Stateful {...statefulProps}>
|
||||
<Interactive.Stateful {...statefulProps} variant="select-card">
|
||||
<div ref={ref} className={cn("opal-select-card", padding, rounding)}>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -33,6 +33,12 @@ export {
|
||||
type LineItemButtonProps,
|
||||
} from "@opal/components/buttons/line-item-button/components";
|
||||
|
||||
/* SidebarTab */
|
||||
export {
|
||||
SidebarTab,
|
||||
type SidebarTabProps,
|
||||
} from "@opal/components/buttons/sidebar-tab/components";
|
||||
|
||||
/* Text */
|
||||
export {
|
||||
Text,
|
||||
|
||||
@@ -1,33 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/disabled/styles.css";
|
||||
import React, { createContext, useContext } from "react";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface DisabledContextValue {
|
||||
isDisabled: boolean;
|
||||
allowClick: boolean;
|
||||
}
|
||||
|
||||
const DisabledContext = createContext<DisabledContextValue>({
|
||||
isDisabled: false,
|
||||
allowClick: false,
|
||||
});
|
||||
|
||||
/**
|
||||
* Returns the current disabled state from the nearest `<Disabled>` ancestor.
|
||||
*
|
||||
* Used internally by `Interactive.Stateless` and `Interactive.Stateful` to
|
||||
* derive `data-disabled` and `aria-disabled` attributes automatically.
|
||||
*/
|
||||
function useDisabled(): DisabledContextValue {
|
||||
return useContext(DisabledContext);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -56,8 +30,8 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Wrapper component that propagates disabled state via context and applies
|
||||
* baseline disabled CSS (opacity, cursor, pointer-events) to its child.
|
||||
* Wrapper component that applies baseline disabled CSS (opacity, cursor,
|
||||
* pointer-events) to its child element.
|
||||
*
|
||||
* Uses Radix `Slot` — merges props onto the single child element without
|
||||
* adding any DOM node. Works correctly inside Radix `asChild` chains.
|
||||
@@ -65,7 +39,7 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Disabled disabled={!canSubmit}>
|
||||
* <Button onClick={handleSubmit}>Save</Button>
|
||||
* <div>...</div>
|
||||
* </Disabled>
|
||||
* ```
|
||||
*/
|
||||
@@ -77,20 +51,16 @@ function Disabled({
|
||||
...rest
|
||||
}: DisabledProps) {
|
||||
return (
|
||||
<DisabledContext.Provider
|
||||
value={{ isDisabled: !!disabled, allowClick: !!allowClick }}
|
||||
<Slot
|
||||
ref={ref}
|
||||
{...rest}
|
||||
aria-disabled={disabled || undefined}
|
||||
data-opal-disabled={disabled || undefined}
|
||||
data-allow-click={disabled && allowClick ? "" : undefined}
|
||||
>
|
||||
<Slot
|
||||
ref={ref}
|
||||
{...rest}
|
||||
aria-disabled={disabled || undefined}
|
||||
data-opal-disabled={disabled || undefined}
|
||||
data-allow-click={disabled && allowClick ? "" : undefined}
|
||||
>
|
||||
{children}
|
||||
</Slot>
|
||||
</DisabledContext.Provider>
|
||||
{children}
|
||||
</Slot>
|
||||
);
|
||||
}
|
||||
|
||||
export { Disabled, useDisabled, type DisabledProps, type DisabledContextValue };
|
||||
export { Disabled, type DisabledProps };
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
/* Disabled */
|
||||
export {
|
||||
Disabled,
|
||||
useDisabled,
|
||||
type DisabledProps,
|
||||
type DisabledContextValue,
|
||||
} from "@opal/core/disabled/components";
|
||||
export { Disabled, type DisabledProps } from "@opal/core/disabled/components";
|
||||
|
||||
/* Animations (formerly Hoverable) */
|
||||
export {
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
widthVariants,
|
||||
type ExtremaSizeVariants,
|
||||
} from "@opal/shared";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -102,7 +101,6 @@ function InteractiveContainer({
|
||||
widthVariant = "fit",
|
||||
...props
|
||||
}: InteractiveContainerProps) {
|
||||
const { allowClick } = useDisabled();
|
||||
const {
|
||||
className: slotClassName,
|
||||
style: slotStyle,
|
||||
@@ -148,8 +146,7 @@ function InteractiveContainer({
|
||||
if (type) {
|
||||
const ariaDisabled = (rest as Record<string, unknown>)["aria-disabled"];
|
||||
const nativeDisabled =
|
||||
(type === "submit" || !allowClick) &&
|
||||
(ariaDisabled === true || ariaDisabled === "true" || undefined);
|
||||
ariaDisabled === true || ariaDisabled === "true" || undefined;
|
||||
return (
|
||||
<button
|
||||
ref={ref as React.Ref<HTMLButtonElement>}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user