mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-01 13:02:42 +00:00
Compare commits
11 Commits
seed-defau
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bab95d8bf0 | ||
|
|
eb7bc74e1b | ||
|
|
29da0aefb5 | ||
|
|
6c86301c51 | ||
|
|
631146f48f | ||
|
|
f327278506 | ||
|
|
c7cc439862 | ||
|
|
3365a369e2 | ||
|
|
470bda3fb5 | ||
|
|
13f511e209 | ||
|
|
c5e8ba1eab |
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -704,9 +704,6 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
@@ -789,9 +786,6 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
@@ -1,54 +0,0 @@
|
||||
"""csv to tabular chat file type
|
||||
|
||||
Revision ID: 8188861f4e92
|
||||
Revises: d8cdfee5df80
|
||||
Create Date: 2026-03-31 19:23:05.753184
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8188861f4e92"
|
||||
down_revision = "d8cdfee5df80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET files = (
|
||||
SELECT jsonb_agg(
|
||||
CASE
|
||||
WHEN elem->>'type' = 'csv'
|
||||
THEN jsonb_set(elem, '{type}', '"tabular"')
|
||||
ELSE elem
|
||||
END
|
||||
)
|
||||
FROM jsonb_array_elements(files) AS elem
|
||||
)
|
||||
WHERE files::text LIKE '%"type": "csv"%'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET files = (
|
||||
SELECT jsonb_agg(
|
||||
CASE
|
||||
WHEN elem->>'type' = 'tabular'
|
||||
THEN jsonb_set(elem, '{type}', '"csv"')
|
||||
ELSE elem
|
||||
END
|
||||
)
|
||||
FROM jsonb_array_elements(files) AS elem
|
||||
)
|
||||
WHERE files::text LIKE '%"type": "tabular"%'
|
||||
"""
|
||||
)
|
||||
@@ -1,136 +0,0 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
op.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
op.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -1,84 +0,0 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -1,116 +0,0 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Exclude inactive placeholder/anonymous users that are not real users
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.is_active == True, # noqa: E712
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Exclude inactive placeholder/anonymous users that are not real users
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.is_active == True, # noqa: E712
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -1,55 +0,0 @@
|
||||
"""add skipped to userfilestatus
|
||||
|
||||
Revision ID: d8cdfee5df80
|
||||
Revises: 1d78c0ca7853
|
||||
Create Date: 2026-04-01 10:47:12.593950
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d8cdfee5df80"
|
||||
down_revision = "1d78c0ca7853"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = (
|
||||
"PROCESSING",
|
||||
"INDEXING",
|
||||
"COMPLETED",
|
||||
"SKIPPED",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
"DELETING",
|
||||
)
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -5,7 +5,6 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.hooks",
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
|
||||
@@ -55,15 +55,6 @@ ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
|
||||
@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
|
||||
# Migrate any pool tenants that were provisioned before a new migration was deployed
|
||||
_migrate_stale_pool_tenants()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def _migrate_stale_pool_tenants() -> None:
|
||||
"""
|
||||
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
|
||||
idempotent, tenants already at head are a fast no-op. This ensures pool
|
||||
tenants are always current so that signup doesn't hit schema mismatches
|
||||
(e.g. missing columns added after the tenant was pre-provisioned).
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
pool_tenants = db_session.query(AvailableTenant).all()
|
||||
tenant_ids = [t.tenant_id for t in pool_tenants]
|
||||
|
||||
if not tenant_ids:
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
|
||||
)
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
run_alembic_migrations(tenant_id)
|
||||
new_version = get_current_alembic_version(tenant_id)
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
if tenant and tenant.alembic_version != new_version:
|
||||
task_logger.info(
|
||||
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
|
||||
)
|
||||
tenant.alembic_version = new_version
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to migrate pool tenant {tenant_id}, skipping"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> bool:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
|
||||
@@ -69,7 +69,5 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
# Hook extensions
|
||||
"/admin/hooks",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -259,7 +255,6 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -274,7 +269,6 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -282,8 +276,6 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -510,9 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
for uid in user_group.user_ids:
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -820,9 +796,6 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
for uid in set(added_user_ids) | set(removed_user_ids):
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -862,17 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids = (
|
||||
db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -901,11 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
for uid in affected_user_ids:
|
||||
recompute_user_permissions__no_commit(uid, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,385 +0,0 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
@@ -15,7 +15,6 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.features.hooks.api import router as hook_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
@@ -139,7 +138,6 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -52,13 +52,11 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
@@ -488,7 +486,6 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -509,25 +506,13 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
|
||||
@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# Run migrations to ensure the pre-provisioned tenant schema is current.
|
||||
# Pool tenants may have been created before a new migration was deployed.
|
||||
# Capture as a non-optional local so mypy can type the lambda correctly.
|
||||
_tenant_id: str = tenant_id
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: run_alembic_migrations(_tenant_id)
|
||||
)
|
||||
except Exception:
|
||||
# The tenant was already dequeued from the pool — roll it back so
|
||||
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
|
||||
logger.exception(
|
||||
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
|
||||
)
|
||||
try:
|
||||
await rollback_tenant_provisioning(_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
|
||||
raise
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
|
||||
@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -60,50 +56,27 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -215,9 +185,6 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,8 +5,6 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -53,16 +50,12 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
@@ -120,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -696,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -746,23 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
user = await self.user_db.get(user.id) # type: ignore[arg-type]
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -848,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -1576,7 +1554,6 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
@@ -66,7 +65,6 @@ if SENTRY_DSN:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -517,8 +515,7 @@ def reset_tenant_id(
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
sender: Any, **kwargs: Any # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
@@ -9,7 +9,6 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -138,7 +137,6 @@ def _docfetching_task(
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
total_batches: int | str = (
|
||||
coordination_status.total_batches
|
||||
if coordination_status.total_batches is not None
|
||||
else "?"
|
||||
)
|
||||
if coordination_status.found:
|
||||
task_logger.info(
|
||||
f"Indexing attempt progress: "
|
||||
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"completed_batches={coordination_status.completed_batches} "
|
||||
f"total_batches={total_batches} "
|
||||
f"total_batches={coordination_status.total_batches or '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
f"elapsed={(current_db_time - attempt.time_created).seconds}"
|
||||
@@ -415,7 +410,7 @@ def check_indexing_completion(
|
||||
logger.info(
|
||||
f"Indexing status: "
|
||||
f"indexing_completed={indexing_completed} "
|
||||
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
|
||||
f"batches_processed={batches_processed}/{batches_total or '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_chunks={coordination_status.total_chunks} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1079,6 +1079,7 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
CANVAS = "canvas"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -278,7 +277,6 @@ class NotificationType(str, Enum):
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
USER_GROUP_ASSIGNMENT_FAILED = "user_group_assignment_failed"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -674,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Permissioning / AccessControl logic for Canvas courses.
|
||||
|
||||
CE stub — returns None (no permissions). The EE implementation is loaded
|
||||
at runtime via ``fetch_versioned_implementation``.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def get_course_permissions(
|
||||
canvas_client: CanvasApiClient,
|
||||
course_id: int,
|
||||
) -> ExternalAccess | None:
|
||||
if not global_version.is_ee_version():
|
||||
return None
|
||||
|
||||
ee_get_course_permissions = cast(
|
||||
Callable[[CanvasApiClient, int], ExternalAccess | None],
|
||||
fetch_versioned_implementation(
|
||||
"onyx.external_permissions.canvas.access",
|
||||
"get_course_permissions",
|
||||
),
|
||||
)
|
||||
|
||||
return ee_get_course_permissions(canvas_client, course_id)
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -191,22 +190,3 @@ class CanvasApiClient:
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
|
||||
def paginate(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Iterator[list[Any]]:
|
||||
"""Yield each page of results, following Link-header pagination.
|
||||
|
||||
Makes the first request with endpoint + params, then follows
|
||||
next_url from Link headers for subsequent pages.
|
||||
"""
|
||||
response, next_url = self.get(endpoint, params=params)
|
||||
while True:
|
||||
if not response:
|
||||
break
|
||||
yield response
|
||||
if not next_url:
|
||||
break
|
||||
response, next_url = self.get(full_url=next_url)
|
||||
|
||||
@@ -1,82 +1,17 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.access import get_course_permissions
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
"""Map Canvas API errors to connector framework exceptions."""
|
||||
if e.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Canvas API token is invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Canvas API error (status={e.status_code}): {e}"
|
||||
)
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
course_code: str | None = None
|
||||
created_at: str | None = None
|
||||
workflow_state: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload.get("name"),
|
||||
course_code=payload.get("course_code"),
|
||||
created_at=payload.get("created_at"),
|
||||
workflow_state=payload.get("workflow_state"),
|
||||
)
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
@@ -84,22 +19,10 @@ class CanvasPage(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
|
||||
return cls(
|
||||
page_id=payload["page_id"],
|
||||
url=payload["url"],
|
||||
title=payload["title"],
|
||||
body=payload.get("body"),
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
@@ -107,23 +30,10 @@ class CanvasAssignment(BaseModel):
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
due_at: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload["name"],
|
||||
description=payload.get("description"),
|
||||
html_url=payload["html_url"],
|
||||
course_id=course_id,
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
due_at=payload.get("due_at"),
|
||||
)
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
@@ -133,17 +43,6 @@ class CanvasAnnouncement(BaseModel):
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
title=payload["title"],
|
||||
message=payload.get("message"),
|
||||
html_url=payload["html_url"],
|
||||
posted_at=payload.get("posted_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
@@ -173,286 +72,3 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
canvas_base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
|
||||
self.batch_size = batch_size
|
||||
self._canvas_client: CanvasApiClient | None = None
|
||||
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
|
||||
|
||||
@property
|
||||
def canvas_client(self) -> CanvasApiClient:
|
||||
if self._canvas_client is None:
|
||||
raise ConnectorMissingCredentialError("Canvas")
|
||||
return self._canvas_client
|
||||
|
||||
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
|
||||
"""Get course permissions with caching."""
|
||||
if course_id not in self._course_permissions_cache:
|
||||
self._course_permissions_cache[course_id] = get_course_permissions(
|
||||
canvas_client=self.canvas_client,
|
||||
course_id=course_id,
|
||||
)
|
||||
return self._course_permissions_cache[course_id]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_courses(self) -> list[CanvasCourse]:
|
||||
"""Fetch all courses accessible to the authenticated user."""
|
||||
logger.debug("Fetching Canvas courses")
|
||||
|
||||
courses: list[CanvasCourse] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"courses", params={"per_page": "100", "state[]": "available"}
|
||||
):
|
||||
courses.extend(CanvasCourse.from_api(c) for c in page)
|
||||
return courses
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_pages(self, course_id: int) -> list[CanvasPage]:
|
||||
"""Fetch all pages for a given course."""
|
||||
logger.debug(f"Fetching pages for course {course_id}")
|
||||
|
||||
pages: list[CanvasPage] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/pages",
|
||||
params={"per_page": "100", "include[]": "body", "published": "true"},
|
||||
):
|
||||
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
|
||||
return pages
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
|
||||
"""Fetch all assignments for a given course."""
|
||||
logger.debug(f"Fetching assignments for course {course_id}")
|
||||
|
||||
assignments: list[CanvasAssignment] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/assignments",
|
||||
params={"per_page": "100", "published": "true"},
|
||||
):
|
||||
assignments.extend(
|
||||
CanvasAssignment.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return assignments
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
|
||||
"""Fetch all announcements for a given course."""
|
||||
logger.debug(f"Fetching announcements for course {course_id}")
|
||||
|
||||
announcements: list[CanvasAnnouncement] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"announcements",
|
||||
params={
|
||||
"per_page": "100",
|
||||
"context_codes[]": f"course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
):
|
||||
announcements.extend(
|
||||
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return announcements
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
link: str,
|
||||
text: str,
|
||||
semantic_identifier: str,
|
||||
doc_updated_at: datetime | None,
|
||||
course_id: int,
|
||||
doc_type: str,
|
||||
) -> Document:
|
||||
"""Build a Document with standard Canvas fields."""
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=cast(
|
||||
list[TextSection | ImageSection],
|
||||
[TextSection(link=link, text=text)],
|
||||
),
|
||||
source=DocumentSource.CANVAS,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=doc_updated_at,
|
||||
metadata={"course_id": str(course_id), "type": doc_type},
|
||||
)
|
||||
|
||||
def _convert_page_to_document(self, page: CanvasPage) -> Document:
|
||||
"""Convert a Canvas page to a Document."""
|
||||
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
|
||||
|
||||
text_parts = [page.title]
|
||||
body_text = parse_html_page_basic(page.body) if page.body else ""
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
link=link,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=page.title or f"Page {page.page_id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=page.course_id,
|
||||
doc_type="page",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
|
||||
"""Convert a Canvas assignment to a Document."""
|
||||
text_parts = [assignment.name]
|
||||
desc_text = (
|
||||
parse_html_page_basic(assignment.description)
|
||||
if assignment.description
|
||||
else ""
|
||||
)
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
|
||||
link=assignment.html_url,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=assignment.course_id,
|
||||
doc_type="assignment",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_announcement_to_document(
|
||||
self, announcement: CanvasAnnouncement
|
||||
) -> Document:
|
||||
"""Convert a Canvas announcement to a Document."""
|
||||
text_parts = [announcement.title]
|
||||
msg_text = (
|
||||
parse_html_page_basic(announcement.message) if announcement.message else ""
|
||||
)
|
||||
if msg_text:
|
||||
text_parts.append(msg_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
announcement.posted_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if announcement.posted_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
|
||||
link=announcement.html_url,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=announcement.course_id,
|
||||
doc_type="announcement",
|
||||
)
|
||||
return document
|
||||
|
||||
@override
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load and validate Canvas credentials."""
|
||||
access_token = credentials.get("canvas_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Canvas")
|
||||
|
||||
try:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=access_token,
|
||||
canvas_base_url=self.canvas_base_url,
|
||||
)
|
||||
client.get("courses", params={"per_page": "1"})
|
||||
except ValueError as e:
|
||||
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
|
||||
except OnyxError as e:
|
||||
_handle_canvas_api_error(e)
|
||||
|
||||
self._canvas_client = client
|
||||
return None
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Canvas connector settings by testing API access."""
|
||||
try:
|
||||
self.canvas_client.get("courses", params={"per_page": "1"})
|
||||
logger.info("Canvas connector settings validated successfully")
|
||||
except OnyxError as e:
|
||||
_handle_canvas_api_error(e)
|
||||
except ConnectorMissingCredentialError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Canvas settings validation: {exc}"
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
# TODO(benwu408): implemented in PR4 (perm sync)
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -11,13 +11,11 @@ from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.errors import LoginFailure
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import CredentialInvalidError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -211,19 +209,8 @@ def _manage_async_retrieval(
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
start_task = asyncio.create_task(discord_client.start(token))
|
||||
ready_task = asyncio.create_task(discord_client.wait_until_ready())
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{start_task, ready_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# start() runs indefinitely once connected, so it only lands
|
||||
# in `done` when login/connection failed — propagate the error.
|
||||
if start_task in done:
|
||||
ready_task.cancel()
|
||||
start_task.result()
|
||||
asyncio.create_task(discord_client.start(token))
|
||||
await discord_client.wait_until_ready()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=discord_client,
|
||||
@@ -289,19 +276,6 @@ class DiscordConnector(PollConnector, LoadConnector):
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
client = Client(intents=Intents.default())
|
||||
try:
|
||||
loop.run_until_complete(client.login(self.discord_bot_token))
|
||||
except LoginFailure as e:
|
||||
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
|
||||
finally:
|
||||
loop.run_until_complete(client.close())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
|
||||
@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.coda.connector",
|
||||
class_name="CodaConnector",
|
||||
),
|
||||
DocumentSource.CANVAS: ConnectorMapping(
|
||||
module_path="onyx.connectors.canvas.connector",
|
||||
class_name="CanvasConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
|
||||
@@ -11,19 +11,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -92,7 +87,6 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -105,21 +99,7 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
# Late import to avoid circular dependency (api_key <- users <- api_key).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = list(result.scalars().all())
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
@@ -631,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -853,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -13,19 +13,19 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -215,7 +215,6 @@ class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
SKIPPED = "SKIPPED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
DELETING = "DELETING"
|
||||
|
||||
@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
@@ -91,18 +90,9 @@ def get_notifications(
|
||||
notif_type: NotificationType | None = None,
|
||||
include_dismissed: bool = True,
|
||||
) -> list[Notification]:
|
||||
if user is None:
|
||||
user_filter = Notification.user_id.is_(None)
|
||||
elif user.role == UserRole.ADMIN:
|
||||
# Admins see their own notifications AND admin-targeted ones (user_id IS NULL)
|
||||
user_filter = or_(
|
||||
Notification.user_id == user.id,
|
||||
Notification.user_id.is_(None),
|
||||
)
|
||||
else:
|
||||
user_filter = Notification.user_id == user.id
|
||||
|
||||
query = select(Notification).where(user_filter)
|
||||
query = select(Notification).where(
|
||||
Notification.user_id == user.id if user else Notification.user_id.is_(None)
|
||||
)
|
||||
if not include_dismissed:
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(user_id: UUID, db_session: Session) -> None:
|
||||
"""Recompute a single user's granted permissions from their group grants.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
stmt = (
|
||||
select(PermissionGrant.permission)
|
||||
.join(
|
||||
User__UserGroup,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id == user_id,
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
rows = db_session.execute(stmt).scalars().all()
|
||||
# sorted for consistent ordering in DB — easier to read when debugging
|
||||
granted = sorted({p.value for p in rows})
|
||||
|
||||
db_session.execute(
|
||||
update(User).where(User.id == user_id).values(effective_permissions=granted)
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = list(
|
||||
db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(user_ids),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID, set[str]] = defaultdict(set)
|
||||
for uid in user_ids:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid)
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
@@ -7,7 +7,6 @@ from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
@@ -18,7 +17,6 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -36,19 +34,9 @@ class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
rejected_files: list[RejectedFile]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Filenames that should be stored but not indexed.
|
||||
skip_indexing_filenames: set[str] = Field(default_factory=set)
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def indexable_files(self) -> list[UserFile]:
|
||||
return [
|
||||
uf
|
||||
for uf in self.user_files
|
||||
if (uf.name or "") not in self.skip_indexing_filenames
|
||||
]
|
||||
|
||||
|
||||
def build_hashed_file_key(file: UploadFile) -> str:
|
||||
name_prefix = (file.filename or "")[:50]
|
||||
@@ -82,7 +70,6 @@ def create_user_files(
|
||||
)
|
||||
if new_temp_id is not None:
|
||||
id_to_temp_id[str(new_id)] = new_temp_id
|
||||
should_skip = (file.filename or "") in categorized_files.skip_indexing
|
||||
new_file = UserFile(
|
||||
id=new_id,
|
||||
user_id=user.id,
|
||||
@@ -94,7 +81,6 @@ def create_user_files(
|
||||
link_url=link_url,
|
||||
content_type=file.content_type,
|
||||
file_type=file.content_type,
|
||||
status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
|
||||
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
# Persist the UserFile first to satisfy FK constraints for association table
|
||||
@@ -112,7 +98,6 @@ def create_user_files(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files.skip_indexing,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,7 +123,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files = categorized_files_result.user_files
|
||||
rejected_files = categorized_files_result.rejected_files
|
||||
id_to_temp_id = categorized_files_result.id_to_temp_id
|
||||
indexable_files = categorized_files_result.indexable_files
|
||||
# Trigger per-file processing immediately for the current tenant
|
||||
tenant_id = get_current_tenant_id()
|
||||
for rejected_file in rejected_files:
|
||||
@@ -150,12 +134,12 @@ def upload_files_to_user_files_with_indexing(
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
@@ -171,7 +155,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -28,11 +27,8 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
@@ -302,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -313,7 +308,6 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -350,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -382,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -503,14 +421,13 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
stmt = (
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -518,11 +435,7 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
DOC_ID_RETRIEVAL = "doc_id_retrieval"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
|
||||
@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -15,7 +15,6 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
|
||||
class OnyxMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
|
||||
TEXT_MIME_TYPES = {
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
@@ -35,12 +34,13 @@ class OnyxMimeTypes:
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
|
||||
)
|
||||
|
||||
EXCLUDED_IMAGE_TYPES = {
|
||||
@@ -53,11 +53,6 @@ class OnyxMimeTypes:
|
||||
|
||||
|
||||
class OnyxFileExtensions:
|
||||
TABULAR_EXTENSIONS = {
|
||||
".csv",
|
||||
".tsv",
|
||||
".xlsx",
|
||||
}
|
||||
PLAIN_TEXT_EXTENSIONS = {
|
||||
".txt",
|
||||
".md",
|
||||
|
||||
@@ -13,14 +13,13 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
# Tabular data files (CSV, XLSX)
|
||||
TABULAR = "tabular"
|
||||
CSV = "csv"
|
||||
|
||||
def is_text_file(self) -> bool:
|
||||
return self in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.TABULAR,
|
||||
ChatFileType.CSV,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -6,7 +7,10 @@ from shared_configs.configs import MULTI_TENANT
|
||||
def require_hook_enabled() -> None:
|
||||
"""FastAPI dependency that gates all hook management endpoints.
|
||||
|
||||
Hooks are only available in single-tenant / self-hosted EE deployments.
|
||||
Hooks are only available in single-tenant / self-hosted deployments with
|
||||
HOOK_ENABLED=true explicitly set. Two layers of protection:
|
||||
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
|
||||
2. HOOK_ENABLED flag — explicit opt-in by the operator
|
||||
|
||||
Use as: Depends(require_hook_enabled)
|
||||
"""
|
||||
@@ -15,3 +19,8 @@ def require_hook_enabled() -> None:
|
||||
OnyxErrorCode.SINGLE_TENANT_ONLY,
|
||||
"Hooks are not available in multi-tenant deployments",
|
||||
)
|
||||
if not HOOK_ENABLED:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.ENV_VAR_GATED,
|
||||
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
|
||||
)
|
||||
|
||||
@@ -1,22 +1,79 @@
|
||||
"""CE hook executor.
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
HookSkipped and HookSoftFailed are real classes kept here because
|
||||
process_message.py (CE code) uses isinstance checks against them.
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
execute_hook is the public entry point. It dispatches to _execute_hook_impl
|
||||
via fetch_versioned_implementation so that:
|
||||
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
|
||||
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
@@ -30,15 +87,277 @@ class HookSoftFailed:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
db_session: Session, # noqa: ARG001
|
||||
hook_point: HookPoint, # noqa: ARG001
|
||||
payload: dict[str, Any], # noqa: ARG001
|
||||
response_type: type[T], # noqa: ARG001
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""CE no-op — hooks are not available without EE."""
|
||||
return HookSkipped()
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def execute_hook(
|
||||
@@ -48,15 +367,25 @@ def execute_hook(
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point.
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
|
||||
Dispatches to the versioned implementation so EE gets the real executor
|
||||
and CE gets the no-op stub, without any changes at the call site.
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
|
||||
return impl(
|
||||
db_session=db_session,
|
||||
hook_point=hook_point,
|
||||
payload=payload,
|
||||
response_type=response_type,
|
||||
)
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
|
||||
5
backend/onyx/hooks/utils.py
Normal file
5
backend/onyx/hooks/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -438,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -454,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -76,18 +76,11 @@ class CategorizedFiles(BaseModel):
|
||||
acceptable: list[UploadFile] = Field(default_factory=list)
|
||||
rejected: list[RejectedFile] = Field(default_factory=list)
|
||||
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
|
||||
# Filenames within `acceptable` that should be stored but not indexed.
|
||||
skip_indexing: set[str] = Field(default_factory=set)
|
||||
|
||||
# Allow FastAPI UploadFile instances
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
def _skip_token_threshold(extension: str) -> bool:
|
||||
"""Return True if this file extension should bypass the token limit."""
|
||||
return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS
|
||||
|
||||
|
||||
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
|
||||
if max(width, height) <= cap:
|
||||
return width, height
|
||||
@@ -271,17 +264,7 @@ def categorize_uploaded_files(
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
exceeds_threshold = (
|
||||
token_threshold is not None and token_count > token_threshold
|
||||
)
|
||||
if exceeds_threshold and _skip_token_threshold(extension):
|
||||
# Exempt extensions (e.g. spreadsheets) are accepted
|
||||
# but flagged to skip indexing — only metadata is
|
||||
# injected into the LLM context.
|
||||
results.acceptable.append(upload)
|
||||
results.acceptable_file_to_token_count[filename] = token_count
|
||||
results.skip_indexing.add(filename)
|
||||
elif exceeds_threshold:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
|
||||
@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -774,13 +773,6 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
|
||||
@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
|
||||
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
|
||||
return ChatFileType.IMAGE
|
||||
|
||||
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
|
||||
return ChatFileType.TABULAR
|
||||
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
|
||||
return ChatFileType.CSV
|
||||
|
||||
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
|
||||
return ChatFileType.DOC
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
@@ -37,7 +38,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -98,7 +98,7 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=not MULTI_TENANT,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
|
||||
@@ -116,7 +116,7 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant EE deployments only.
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -10,7 +9,6 @@ from typing_extensions import override
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_chat_file_by_id
|
||||
@@ -171,13 +169,10 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
|
||||
|
||||
chat_file = self._load_file(file_id)
|
||||
|
||||
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
|
||||
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
|
||||
# DOC type in a loaded file means plaintext extraction failed and the
|
||||
# content is the original binary (e.g. raw PDF/DOCX bytes).
|
||||
if chat_file.file_type not in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.TABULAR,
|
||||
):
|
||||
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
|
||||
raise ToolCallException(
|
||||
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
|
||||
llm_facing_message=(
|
||||
@@ -186,19 +181,7 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
try:
|
||||
if chat_file.file_type == ChatFileType.PLAIN_TEXT:
|
||||
full_text = chat_file.content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
full_text = (
|
||||
extract_file_text(
|
||||
file=io.BytesIO(chat_file.content),
|
||||
file_name=chat_file.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
except ToolCallException:
|
||||
raise
|
||||
full_text = chat_file.content.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
raise ToolCallException(
|
||||
message=f"Failed to decode file {file_id}",
|
||||
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
|
||||
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
|
||||
MAX_REQUEST_ATTEMPTS = 5
|
||||
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
QUESTION_TIMEOUT_SECONDS = 300
|
||||
QUESTION_RETRY_PAUSE_SECONDS = 30
|
||||
MAX_QUESTION_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
|
||||
return f"{normalized}/api"
|
||||
|
||||
|
||||
def load_completed_question_ids(output_file: Path) -> set[str]:
|
||||
if not output_file.exists():
|
||||
return set()
|
||||
|
||||
completed_ids: set[str] = set()
|
||||
with output_file.open("r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
record = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
question_id = record.get("question_id")
|
||||
if isinstance(question_id, str) and question_id:
|
||||
completed_ids.add(question_id)
|
||||
|
||||
return completed_ids
|
||||
|
||||
|
||||
def load_questions(questions_file: Path) -> list[QuestionRecord]:
|
||||
if not questions_file.exists():
|
||||
raise FileNotFoundError(f"Questions file not found: {questions_file}")
|
||||
@@ -373,7 +348,6 @@ async def generate_answers(
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
parallelism: int,
|
||||
skipped: int,
|
||||
) -> None:
|
||||
if parallelism < 1:
|
||||
raise ValueError("`--parallelism` must be at least 1.")
|
||||
@@ -408,178 +382,58 @@ async def generate_answers(
|
||||
write_lock = asyncio.Lock()
|
||||
completed = 0
|
||||
successful = 0
|
||||
stuck_count = 0
|
||||
failed_questions: list[FailedQuestionRecord] = []
|
||||
remaining_count = len(questions)
|
||||
overall_total = remaining_count + skipped
|
||||
question_durations: list[float] = []
|
||||
run_start_time = time.monotonic()
|
||||
|
||||
def print_progress() -> None:
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
elapsed = time.monotonic() - run_start_time
|
||||
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
|
||||
|
||||
done = skipped + completed
|
||||
bar_width = 30
|
||||
filled = (
|
||||
int(bar_width * done / overall_total)
|
||||
if overall_total
|
||||
else bar_width
|
||||
)
|
||||
bar = "█" * filled + "░" * (bar_width - filled)
|
||||
pct = (done / overall_total * 100) if overall_total else 100.0
|
||||
|
||||
parts = (
|
||||
f"\r{bar} {pct:5.1f}% "
|
||||
f"[{done}/{overall_total}] "
|
||||
f"avg {avg_time:.1f}s/q "
|
||||
f"elapsed {elapsed:.0f}s "
|
||||
f"ETA {eta:.0f}s "
|
||||
f"(ok:{successful} fail:{len(failed_questions)}"
|
||||
)
|
||||
if stuck_count:
|
||||
parts += f" stuck:{stuck_count}"
|
||||
if skipped:
|
||||
parts += f" skip:{skipped}"
|
||||
parts += ")"
|
||||
|
||||
sys.stderr.write(parts)
|
||||
sys.stderr.flush()
|
||||
|
||||
print_progress()
|
||||
total = len(questions)
|
||||
|
||||
async def process_question(question_record: QuestionRecord) -> None:
|
||||
nonlocal completed
|
||||
nonlocal successful
|
||||
nonlocal stuck_count
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
|
||||
q_start = time.monotonic()
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await asyncio.wait_for(
|
||||
submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
),
|
||||
timeout=QUESTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
async with progress_lock:
|
||||
stuck_count += 1
|
||||
logger.warning(
|
||||
"Question %s timed out after %ss (attempt %s/%s, "
|
||||
"total stuck: %s) — retrying in %ss",
|
||||
question_record.question_id,
|
||||
QUESTION_TIMEOUT_SECONDS,
|
||||
attempt,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
stuck_count,
|
||||
QUESTION_RETRY_PAUSE_SECONDS,
|
||||
)
|
||||
print_progress()
|
||||
last_error = TimeoutError(
|
||||
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
|
||||
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
)
|
||||
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
|
||||
continue
|
||||
except Exception as exc:
|
||||
duration = time.monotonic() - q_start
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
question_durations.append(duration)
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
return
|
||||
|
||||
duration = time.monotonic() - q_start
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
except Exception as exc:
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
successful += 1
|
||||
question_durations.append(duration)
|
||||
print_progress()
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
total,
|
||||
)
|
||||
return
|
||||
|
||||
# All attempts exhausted due to timeouts
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(last_error),
|
||||
)
|
||||
)
|
||||
logger.error(
|
||||
"Question %s failed after %s timeout attempts (%s/%s)",
|
||||
question_record.question_id,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
successful += 1
|
||||
logger.info("Processed %s/%s questions", completed, total)
|
||||
|
||||
await asyncio.gather(
|
||||
*(process_question(question_record) for question_record in questions)
|
||||
)
|
||||
|
||||
# Final newline after progress bar
|
||||
sys.stderr.write("\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
total_elapsed = time.monotonic() - run_start_time
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
|
||||
resume_suffix = (
|
||||
f" — {skipped} previously completed, "
|
||||
f"{skipped + successful}/{overall_total} overall"
|
||||
if skipped
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
|
||||
successful,
|
||||
remaining_count,
|
||||
total_elapsed,
|
||||
avg_time,
|
||||
stuck_suffix,
|
||||
resume_suffix,
|
||||
)
|
||||
|
||||
if failed_questions:
|
||||
logger.warning(
|
||||
"%s questions failed:",
|
||||
"Completed with %s failed questions and %s successful questions.",
|
||||
len(failed_questions),
|
||||
successful,
|
||||
)
|
||||
for failed_question in failed_questions:
|
||||
logger.warning(
|
||||
@@ -599,30 +453,7 @@ def main() -> None:
|
||||
raise ValueError("`--max-questions` must be at least 1 when provided.")
|
||||
questions = questions[: args.max_questions]
|
||||
|
||||
completed_ids = load_completed_question_ids(args.output_file)
|
||||
logger.info(
|
||||
"Found %s already-answered question IDs in %s",
|
||||
len(completed_ids),
|
||||
args.output_file,
|
||||
)
|
||||
total_before_filter = len(questions)
|
||||
questions = [q for q in questions if q.question_id not in completed_ids]
|
||||
skipped = total_before_filter - len(questions)
|
||||
|
||||
if skipped:
|
||||
logger.info(
|
||||
"Resuming: %s/%s already answered, %s remaining",
|
||||
skipped,
|
||||
total_before_filter,
|
||||
len(questions),
|
||||
)
|
||||
else:
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
|
||||
if not questions:
|
||||
logger.info("All questions already answered. Nothing to do.")
|
||||
return
|
||||
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
logger.info("Writing answers to %s", args.output_file)
|
||||
|
||||
asyncio.run(
|
||||
@@ -632,7 +463,6 @@ def main() -> None:
|
||||
api_base=api_base,
|
||||
api_key=args.api_key,
|
||||
parallelism=args.parallelism,
|
||||
skipped=skipped,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -74,8 +68,7 @@ def create_test_user(
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(
|
||||
|
||||
file_descriptor: FileDescriptor = {
|
||||
"id": user_file.file_id,
|
||||
"type": ChatFileType.TABULAR,
|
||||
"type": ChatFileType.CSV,
|
||||
"name": "data.csv",
|
||||
"user_file_id": str(user_file.id),
|
||||
}
|
||||
|
||||
@@ -126,15 +126,6 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,30 +104,13 @@ class UserGroupManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser,
|
||||
include_default: bool = False,
|
||||
) -> list[UserGroup]:
|
||||
params: dict[str, str] = {}
|
||||
if include_default:
|
||||
params["include_default"] = "true"
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -37,120 +33,3 @@ def test_limited(reset: None) -> None: # noqa: ARG001
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def _get_service_account_account_type(
|
||||
admin_user: DATestUser,
|
||||
api_key_user_id: UUID,
|
||||
) -> AccountType:
|
||||
"""Fetch the account_type of a service account user via the user listing API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/users",
|
||||
headers=admin_user.headers,
|
||||
params={"include_api_keys": "true"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
user_id_str = str(api_key_user_id)
|
||||
for user in data["accepted"]:
|
||||
if user["id"] == user_id_str:
|
||||
return AccountType(user["account_type"])
|
||||
raise AssertionError(
|
||||
f"Service account user {user_id_str} not found in user listing"
|
||||
)
|
||||
|
||||
|
||||
def _get_default_group_user_ids(
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
admin_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_ids = {str(u.id) for u in basic_group.users}
|
||||
return admin_ids, basic_ids
|
||||
|
||||
|
||||
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify no group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "LIMITED API key should NOT be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "LIMITED API key should NOT be in Basic default group"
|
||||
|
||||
|
||||
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.BASIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Basic group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "BASIC API key should NOT be in Admin default group"
|
||||
|
||||
|
||||
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.ADMIN,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Admin group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "ADMIN API key should NOT be in Basic default group"
|
||||
|
||||
@@ -4,10 +4,8 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -97,63 +95,3 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD when converting a
|
||||
non-web user (EXT_PERM_USER) and that the user receives the correct role
|
||||
(BASIC) after conversion.
|
||||
|
||||
This validates the permissions-migration-phase2 changes which ensure that:
|
||||
1. account_type is updated to 'standard' on SAML conversion
|
||||
2. The converted user is assigned to the Basic default group
|
||||
"""
|
||||
# Create an admin user (first user is automatically admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# Create a user and set them as EXT_PERM_USER
|
||||
test_email = "ext_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = response.json()
|
||||
|
||||
# Verify account_type is set to standard after conversion
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
|
||||
|
||||
# Verify role is BASIC after conversion
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user was assigned to the Basic default group
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
|
||||
assert basic_default, "Basic default group not found"
|
||||
|
||||
basic_group = basic_default[0]
|
||||
member_emails = {u.email for u in basic_group.users}
|
||||
assert test_email in member_emails, (
|
||||
f"Converted user '{test_email}' not found in Basic default group members: "
|
||||
f"{member_emails}"
|
||||
)
|
||||
|
||||
@@ -35,16 +35,9 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
@@ -218,49 +211,6 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_create_user_default_group_and_account_type(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
|
||||
email = f"scim_defaults_{idp_style}@example.com"
|
||||
ext_id = f"ext-defaults-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
user_id = resp.json()["id"]
|
||||
|
||||
# --- Verify group assignment via SCIM GET ---
|
||||
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
|
||||
assert get_resp.status_code == 200
|
||||
groups = get_resp.json().get("groups", [])
|
||||
group_names = {g["display"] for g in groups}
|
||||
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
|
||||
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
|
||||
|
||||
# --- Verify account_type via admin API ---
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
page = UserManager.get_user_page(
|
||||
user_performing_action=admin,
|
||||
search_query=email,
|
||||
)
|
||||
assert page.total_items >= 1
|
||||
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
|
||||
assert (
|
||||
scim_user_snapshot is not None
|
||||
), f"SCIM user {email} not found in user listing"
|
||||
assert (
|
||||
scim_user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
import mimetypes
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
@@ -85,90 +79,3 @@ def test_send_message_with_text_file_attachment(admin_user: DATestUser) -> None:
|
||||
assert (
|
||||
"third line" in response.full_message.lower()
|
||||
), "Chat response should contain the contents of the file"
|
||||
|
||||
|
||||
def _set_token_threshold(admin_user: DATestUser, threshold_k: int) -> None:
|
||||
"""Set the file token count threshold via admin settings API."""
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/settings",
|
||||
json={"file_token_count_threshold_k": threshold_k},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def _upload_raw(
|
||||
filename: str,
|
||||
content: bytes,
|
||||
user: DATestUser,
|
||||
) -> dict[str, Any]:
|
||||
"""Upload a file and return the full JSON response (user_files + rejected_files)."""
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
headers = user.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/upload",
|
||||
files=[("files", (filename, content, mime_type or "application/octet-stream"))],
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_csv_over_token_threshold_uploaded_not_indexed(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""CSV exceeding token threshold is uploaded (accepted) but skips indexing."""
|
||||
_set_token_threshold(admin_user, threshold_k=1)
|
||||
try:
|
||||
# ~2000 tokens with default tokenizer, well over 1K threshold
|
||||
content = ("x " * 100 + "\n") * 20
|
||||
result = _upload_raw("large.csv", content.encode(), admin_user)
|
||||
|
||||
assert len(result["user_files"]) == 1, "CSV should be accepted"
|
||||
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
|
||||
assert (
|
||||
result["user_files"][0]["status"] == "SKIPPED"
|
||||
), "CSV over threshold should be SKIPPED (uploaded but not indexed)"
|
||||
assert (
|
||||
result["user_files"][0]["chunk_count"] is None
|
||||
), "Skipped file should have no chunks"
|
||||
finally:
|
||||
_set_token_threshold(admin_user, threshold_k=200)
|
||||
|
||||
|
||||
def test_csv_under_token_threshold_uploaded_and_indexed(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""CSV under token threshold is uploaded and queued for indexing."""
|
||||
_set_token_threshold(admin_user, threshold_k=200)
|
||||
try:
|
||||
content = "col1,col2\na,b\n"
|
||||
result = _upload_raw("small.csv", content.encode(), admin_user)
|
||||
|
||||
assert len(result["user_files"]) == 1, "CSV should be accepted"
|
||||
assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
|
||||
assert (
|
||||
result["user_files"][0]["status"] == "PROCESSING"
|
||||
), "CSV under threshold should be PROCESSING (queued for indexing)"
|
||||
finally:
|
||||
_set_token_threshold(admin_user, threshold_k=200)
|
||||
|
||||
|
||||
def test_txt_over_token_threshold_rejected(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Non-exempt file exceeding token threshold is rejected entirely."""
|
||||
_set_token_threshold(admin_user, threshold_k=1)
|
||||
try:
|
||||
# ~2000 tokens, well over 1K threshold. Unlike CSV, .txt is not
|
||||
# exempt from the threshold so the file should be rejected.
|
||||
content = ("x " * 100 + "\n") * 20
|
||||
result = _upload_raw("big.txt", content.encode(), admin_user)
|
||||
|
||||
assert len(result["user_files"]) == 0, "File should not be accepted"
|
||||
assert len(result["rejected_files"]) == 1, "File should be rejected"
|
||||
assert "token limit" in result["rejected_files"][0]["reason"].lower()
|
||||
finally:
|
||||
_set_token_threshold(admin_user, threshold_k=200)
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_gets_permissions_when_added_to_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
|
||||
|
||||
# basic_user starts with only "basic" from the default group
|
||||
initial_permissions = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in initial_permissions
|
||||
assert "add:agents" not in initial_permissions
|
||||
|
||||
# Create a new group and add basic_user
|
||||
group = UserGroupManager.create(
|
||||
name="perm-test-group",
|
||||
user_ids=[admin_user.id, basic_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Grant a non-basic permission to the group and recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_group = db_session.get(UserGroupModel, group.id)
|
||||
assert db_group is not None
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
recompute_user_permissions__no_commit(basic_user.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the user gained the new permission (expanded includes read:agents)
|
||||
updated_permissions = UserManager.get_permissions(basic_user)
|
||||
assert (
|
||||
"add:agents" in updated_permissions
|
||||
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
|
||||
assert (
|
||||
"read:agents" in updated_permissions
|
||||
), f"User should have implied 'read:agents', got: {updated_permissions}"
|
||||
assert "basic" in updated_permissions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_group_permission_change_propagates_to_all_members(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_propagate")
|
||||
user_a: DATestUser = UserManager.create(name="user_a_propagate")
|
||||
user_b: DATestUser = UserManager.create(name="user_b_propagate")
|
||||
|
||||
group = UserGroupManager.create(
|
||||
name="propagate-test-group",
|
||||
user_ids=[admin_user.id, user_a.id, user_b.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Neither user should have add:agents yet
|
||||
for u in (user_a, user_b):
|
||||
assert "add:agents" not in UserManager.get_permissions(u)
|
||||
|
||||
# Grant add:agents to the group, then batch-recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
grant = PermissionGrant(
|
||||
group_id=group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
db_session.add(grant)
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Both users should now have the permission (plus implied read:agents)
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
|
||||
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
|
||||
|
||||
# Soft-delete the grant and recompute — permission should be removed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_grant = (
|
||||
db_session.query(PermissionGrant)
|
||||
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
|
||||
.first()
|
||||
)
|
||||
assert db_grant is not None
|
||||
db_grant.is_deleted = True
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
|
||||
@@ -1,30 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
|
||||
|
||||
user_group = UserGroupManager.create(
|
||||
name="basic-perm-test-group",
|
||||
user_ids=[admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
permissions = UserGroupManager.get_permissions(
|
||||
user_group=user_group,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
"basic" in permissions
|
||||
), f"New group should have 'basic' permission, got: {permissions}"
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Integration tests for default group assignment on user registration.
|
||||
|
||||
Verifies that:
|
||||
- The first registered user is assigned to the Admin default group
|
||||
- Subsequent registered users are assigned to the Basic default group
|
||||
- account_type is set to STANDARD for email/password registrations
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
|
||||
# Register first user — should become admin
|
||||
admin_user: DATestUser = UserManager.create(name="first_user")
|
||||
assert admin_user.role == UserRole.ADMIN
|
||||
|
||||
# Register second user — should become basic
|
||||
basic_user: DATestUser = UserManager.create(name="second_user")
|
||||
assert basic_user.role == UserRole.BASIC
|
||||
|
||||
# Fetch all groups including default ones
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
|
||||
# Find the default Admin and Basic groups
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
# Verify admin user is in Admin group and NOT in Basic group
|
||||
admin_group_user_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_group_user_ids = {str(u.id) for u in basic_group.users}
|
||||
|
||||
assert (
|
||||
admin_user.id in admin_group_user_ids
|
||||
), "First user should be in Admin default group"
|
||||
assert (
|
||||
admin_user.id not in basic_group_user_ids
|
||||
), "First user should NOT be in Basic default group"
|
||||
|
||||
# Verify basic user is in Basic group and NOT in Admin group
|
||||
assert (
|
||||
basic_user.id in basic_group_user_ids
|
||||
), "Second user should be in Basic default group"
|
||||
assert (
|
||||
basic_user.id not in admin_group_user_ids
|
||||
), "Second user should NOT be in Admin default group"
|
||||
|
||||
# Verify account_type is STANDARD for both users via user listing API
|
||||
paginated_result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
users_by_id = {str(u.id): u for u in paginated_result.items}
|
||||
|
||||
admin_snapshot = users_by_id.get(admin_user.id)
|
||||
basic_snapshot = users_by_id.get(basic_user.id)
|
||||
assert admin_snapshot is not None, "Admin user not found in user listing"
|
||||
assert basic_snapshot is not None, "Basic user not found in user listing"
|
||||
|
||||
assert (
|
||||
admin_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
|
||||
assert (
|
||||
basic_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.permissions import ALL_PERMISSIONS
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.permissions import resolve_effective_permissions
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_effective_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveEffectivePermissions:
|
||||
def test_empty_set(self) -> None:
|
||||
assert resolve_effective_permissions(set()) == set()
|
||||
|
||||
def test_basic_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"basic"})
|
||||
assert result == {"basic"}
|
||||
|
||||
def test_single_implication(self) -> None:
|
||||
result = resolve_effective_permissions({"add:agents"})
|
||||
assert result == {"add:agents", "read:agents"}
|
||||
|
||||
def test_manage_agents_implies_add_and_read(self) -> None:
|
||||
"""manage:agents directly maps to {add:agents, read:agents}."""
|
||||
result = resolve_effective_permissions({"manage:agents"})
|
||||
assert result == {"manage:agents", "add:agents", "read:agents"}
|
||||
|
||||
def test_manage_connectors_chain(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:connectors"})
|
||||
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
|
||||
|
||||
def test_manage_document_sets(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:document_sets"})
|
||||
assert result == {
|
||||
"manage:document_sets",
|
||||
"read:document_sets",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_manage_user_groups_implies_all_reads(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:user_groups"})
|
||||
assert result == {
|
||||
"manage:user_groups",
|
||||
"read:connectors",
|
||||
"read:document_sets",
|
||||
"read:agents",
|
||||
"read:users",
|
||||
}
|
||||
|
||||
def test_admin_override(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_admin_with_others(self) -> None:
|
||||
result = resolve_effective_permissions({"admin", "basic"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_multi_group_union(self) -> None:
|
||||
result = resolve_effective_permissions(
|
||||
{"add:agents", "manage:connectors", "basic"}
|
||||
)
|
||||
assert result == {
|
||||
"basic",
|
||||
"add:agents",
|
||||
"read:agents",
|
||||
"manage:connectors",
|
||||
"add:connectors",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_toggle_permission_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"read:agent_analytics"})
|
||||
assert result == {"read:agent_analytics"}
|
||||
|
||||
def test_all_permissions_for_admin(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert len(result) == len(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_permissions (expands implied at read time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEffectivePermissions:
|
||||
def test_expands_implied_permissions(self) -> None:
|
||||
"""Column stores only granted; get_effective_permissions expands implied."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["add:agents"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
|
||||
|
||||
def test_admin_expands_to_all(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set(Permission)
|
||||
|
||||
def test_basic_stays_basic(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.BASIC_ACCESS}
|
||||
|
||||
def test_empty_column(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_permission (FastAPI dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass(self) -> None:
|
||||
"""Admin stored in column should pass any permission check."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_required_permission(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implied_permission_passes(self) -> None:
|
||||
"""manage:connectors implies read:connectors at read time."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.READ_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_permission_raises(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await dep(user=user)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_permissions_fails(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
|
||||
dep = require_permission(Permission.BASIC_ACCESS)
|
||||
with pytest.raises(OnyxError):
|
||||
await dep(user=user)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
Unit tests for UserCreate schema dict methods.
|
||||
|
||||
Verifies that account_type is always included in create_update_dict
|
||||
and create_update_dict_superuser.
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
def test_create_update_dict_includes_default_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_includes_explicit_account_type() -> None:
|
||||
uc = UserCreate(
|
||||
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
|
||||
)
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_superuser_includes_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict_superuser()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user