Compare commits

..

76 Commits

Author SHA1 Message Date
SubashMohan
508c10daaf fix(user-create): ensure account_type is always set to STANDARD in create_update_dict 2026-04-01 14:29:48 +05:30
SubashMohan
c0003f0b77 fix(user-create): enforce STANDARD account type for self-registration 2026-04-01 13:04:55 +05:30
SubashMohan
dbf202ba05 fix(users): move user group assignment inside tenant context for proper schema targeting 2026-04-01 13:04:55 +05:30
SubashMohan
4987c9581f feat(enums): update GrantSource values to uppercase for consistency 2026-04-01 13:04:55 +05:30
SubashMohan
6a69347bd6 fix(migration): remove check for deleted permissions in basic grant logic 2026-04-01 13:04:55 +05:30
SubashMohan
b8cd332a36 refactor(users): remove user assignment to default groups after account upgrade 2026-04-01 13:04:55 +05:30
SubashMohan
e319728544 feat(permissions): enhance permission recomputation logic to exclude deleted grants 2026-04-01 13:04:55 +05:30
SubashMohan
b42a7858ed fix(docs): clarify user permissions aggregation in effective_permissions migration 2026-04-01 13:04:55 +05:30
SubashMohan
44e2fccbef feat(enums): update AccountType values to uppercase for consistency 2026-04-01 13:04:55 +05:30
SubashMohan
fbf0effcdf fix(permissions): ensure valid user IDs are processed in group permission recomputation 2026-04-01 13:04:55 +05:30
SubashMohan
a71c18c454 refactor(permissions): rename IMPLIES to IMPLIED_PERMISSIONS for clarity 2026-04-01 13:04:55 +05:30
SubashMohan
ccc813e075 feat(permissions): enhance effective permissions logic with user group and permission grant tables 2026-04-01 13:04:55 +05:30
SubashMohan
6eeeda2ab8 feat(permissions): refactor permission recomputation logic and introduce no-commit variants 2026-04-01 13:04:55 +05:30
SubashMohan
7cc4bf2286 feat(permissions): add endpoint to retrieve current user permissions and update permissions on user group changes 2026-04-01 13:04:55 +05:30
SubashMohan
ab0eeb5585 feat(permissions): grant basic permission to new user groups and add API endpoint for permission retrieval 2026-04-01 13:04:55 +05:30
SubashMohan
12b0b01787 feat(permissions): add effective_permissions JSONB column and related logic for permission resolution 2026-04-01 13:04:55 +05:30
SubashMohan
cb11a5c472 feat(notification): add USER_GROUP_ASSIGNMENT_FAILED type and improve notification query logic 2026-04-01 13:04:55 +05:30
SubashMohan
ee80b91b20 feat(migration): add error handling for missing default groups in user assignment
feat(users): update account type for external permissioned users to BOT
fix(tests): update tests to reflect changes in default group assignment behavior
feat(types): introduce AccountType enum for better user role management
2026-04-01 13:04:55 +05:30
SubashMohan
3a643067b9 fix(tests): comment out group visibility checks pending permission implementation 2026-04-01 13:04:55 +05:30
SubashMohan
c4c1de8d87 feat(migration): enhance downgrade logic to prevent FK violations and improve user creation tests with account type handling 2026-04-01 13:04:55 +05:30
SubashMohan
346de4fb39 fix(users): update permission handling in default group assignment logic 2026-04-01 13:04:55 +05:30
SubashMohan
4c08d1730f feat(users): add account_type to user creation and assign to default groups 2026-04-01 13:04:55 +05:30
SubashMohan
cfc2881b97 feat(migration): assign existing users to default groups
Add migration to assign all existing users to Admin/Basic default groups
based on their role and account_type. Add integration tests for group
assignment across all user creation paths (API keys, SAML, SCIM, registration).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 13:04:55 +05:30
SubashMohan
3c038165bb feat(users): backfill account_type and assign new users to default groups
Add migration to backfill account_type column (NOT NULL, default STANDARD)
based on existing user roles. Update all user creation paths (OAuth, SAML,
SCIM, API keys, Slack, anonymous) to set account_type and assign users to
appropriate default groups. Add comprehensive unit tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 13:04:55 +05:30
SubashMohan
04bdfe4749 feat(groups): seed default Admin and Basic user groups
Add migration to create system default user groups (Admin, Basic) with
permission grants. Update API and frontend to support is_default flag,
protect default groups from rename/delete, and add include_default
query parameter for group listing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-01 13:04:55 +05:30
Danelegend
de0f42f6cc refactor(files): Port csv type to tabular (#9785) 2026-04-01 03:37:13 +00:00
Raunak Bhagat
7ecefdc90f refactor(opal): split Card sizeVariant into padding + rounding (#9823) 2026-04-01 03:32:08 +00:00
Danelegend
21fc013893 feat(file-upload): Upload files exceeding tokens but skip indexing (#9751) 2026-04-01 02:14:51 +00:00
Justin Tahara
a1c3a68ba4 fix(perf): optimize chat sessions query to prevent DB cascading failures (#9802) 2026-04-01 01:28:37 +00:00
Evan Lohn
4fb175ae65 fix: install early exit (#9818) 2026-04-01 01:09:05 +00:00
Evan Lohn
800ad326df fix: discord token validation (#9817) 2026-04-01 01:08:38 +00:00
Bo-Onyx
6b920e8a3e feat(hook): refactor under ee (#9776) 2026-04-01 01:07:55 +00:00
Justin Tahara
ef3760796d feat(rds): Adding IO Metrics Alarms (#9789) 2026-04-01 01:07:45 +00:00
Jessica Singh
fa5b90df92 fix(connectors): fix reindex on paused file connectors (#9812) 2026-03-31 23:10:09 +00:00
Evan Lohn
53953ac4fa chore: fix indexing log2 (#9811) 2026-03-31 21:02:54 +00:00
Yuhong Sun
26bb5c990c chore: Rag script for benchmark/regression (#9781) 2026-03-31 20:46:17 +00:00
Evan Lohn
27b4ed301f chore: fix batch logging (#9808) 2026-03-31 20:10:33 +00:00
Jessica Singh
93ec270ccc feat(voice): VAD auto-stop only when auto-send is enabled (#9809) 2026-03-31 19:31:31 +00:00
Raunak Bhagat
9e2d6c8a1d refactor(admin): code-interpreter (#9790) 2026-03-31 19:08:55 +00:00
Nikolas Garza
fc934214d0 perf(swr): add SWR_KEYS registry and skip revalidation for stable hooks (#9695) 2026-03-31 19:07:42 +00:00
Raunak Bhagat
48fc45a0cd refactor(admin): web-search (#9761) 2026-03-31 19:04:18 +00:00
Jessica Singh
009266e53e fix(llm): when multiple providers are same type ensure name is prioritized when default (#9777) 2026-03-31 19:03:38 +00:00
Raunak Bhagat
ffb9df7308 refactor(admin): LLM Config (#9806) 2026-03-31 19:03:17 +00:00
Raunak Bhagat
b0f5e0b8d9 refactor(admin): image-generation (#9769) 2026-03-31 18:13:23 +00:00
acaprau
43aea5d614 chore(opensearch): Add Grafana dashboard for retrieval (#9657)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-31 16:56:40 +00:00
Bo-Onyx
593d82f431 feat(hook): hook status and logs (#9770) 2026-03-31 16:10:12 +00:00
Ben Wu
adf5691b5f feat(canvas 2/4): Canvas Connector data fetching (#9386) 2026-03-31 03:07:05 +00:00
Nikolas Garza
c1a8a5bd83 fix(tenants): run migrations on pool tenants before assigning to new users (#9788) 2026-03-31 01:24:01 +00:00
Justin Tahara
8fd486da99 feat(rds): Add Freeable Memory alert (#9787) 2026-03-31 00:59:30 +00:00
Raunak Bhagat
4bda4d3637 refactor: migrate away from cards/Select (#9771) 2026-03-31 00:27:01 +00:00
Justin Tahara
13c25eadad feat(rds): Adding CPU Alerts (#9784) 2026-03-31 00:22:15 +00:00
Justin Tahara
1f244e6388 feat(eks): Adding Cloudwatch logging (#9783) 2026-03-30 23:52:44 +00:00
Nikolas Garza
18b0416d30 feat(sentry): enable frontend source map uploads in cloud CI (#9775) 2026-03-30 23:42:57 +00:00
Nikolas Garza
4bc0bc1efb feat(helm): add Grafana dashboard provisioning (#9725) 2026-03-30 23:42:32 +00:00
Justin Tahara
1555217061 feat(rds): Adding RDS Snapshosts (#9779) 2026-03-30 23:17:08 +00:00
Nikolas Garza
d177a833f0 feat(sentry): add release tracking to backend and frontend (#9773) 2026-03-30 22:35:38 +00:00
Jamison Lahman
086997d3c5 chore(types): fix IconButton size props (#9772) 2026-03-30 21:40:25 +00:00
dependabot[bot]
dccec78397 chore(deps): bump helm/chart-testing-action from b5eebdd9998021f29756c53432f48dab66394810 to 2e2940618cb426dce2999631d543b53cdcfc8527 (#9764)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 14:41:01 -07:00
Jamison Lahman
0123133621 chore(fe): polish Query History table (#9767) 2026-03-30 21:30:13 +00:00
dependabot[bot]
0b9d154a73 chore(deps): bump runs-on/cache from 50350ad4242587b6c8c2baa2e740b1bc11285ff4 to a5f51d6f3fece787d03b7b4e981c82538a0654ed (#9763)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:54:43 -07:00
dependabot[bot]
6e65e55bf5 chore(deps): bump actions/cache from 5.0.3 to 5.0.4 (#9765)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 13:46:53 -07:00
Raunak Bhagat
3f9e208759 feat(opal): SelectCard + CardHeaderLayout (#9760) 2026-03-30 19:54:54 +00:00
dependabot[bot]
fb8edda14a chore(deps): bump pygments from 2.19.2 to 2.20.0 (#9757)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-30 18:30:18 +00:00
Jamison Lahman
58decd8a6b chore(gha): prefer ci-protected env (#9728) 2026-03-30 17:20:54 +00:00
Danelegend
e97204d9cc feat(indexing): Batch chunks during doc processing (#9468) 2026-03-30 11:49:36 +00:00
Danelegend
44ab02c94f refactor(indexing): Refactor indexing vector db abstraction (#9653) 2026-03-30 09:57:16 +00:00
Danelegend
a98cc30f25 refactor(indexing): Change adapters to support iterables (#9469) 2026-03-30 01:43:10 +00:00
Danelegend
a709dcb8fa feat(indexing): Max chunk processing (#9400) 2026-03-30 00:10:24 +00:00
Raunak Bhagat
a3dfe6aa1b refactor(opal): unify Interactive color system (#9717) 2026-03-28 00:40:23 +00:00
Nikolas Garza
23e4d55fb1 perf(swr): convert raw-fetch hooks to SWR to eliminate duplicate requests (#9694) 2026-03-28 00:26:20 +00:00
Jamison Lahman
470cc85f83 feat(cli): onyx-cli serve over SSH (#9726) 2026-03-27 23:46:14 +00:00
Justin Tahara
64d9be5a41 fix(openpyxl): Colors must be aRGB hex values (#9727) 2026-03-27 23:14:36 +00:00
roshan
71a5b469b0 feat(widget): add citation badges to chat widget (#9714) 2026-03-27 22:39:46 +00:00
Evan Lohn
462eb0697f fix: Anthropic litellm thinking workaround (#9713) 2026-03-27 21:03:05 +00:00
dependabot[bot]
b708dc8796 chore(deps): bump langchain-core from 1.2.11 to 1.2.22 (#9720)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:50:19 +00:00
dependabot[bot]
c9e2c32f55 chore(deps): bump cryptography from 46.0.5 to 46.0.6 (#9721)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 20:48:59 +00:00
309 changed files with 14726 additions and 4195 deletions

View File

@@ -704,6 +704,9 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -786,6 +789,9 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -35,6 +35,7 @@ jobs:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 5
steps:
- name: Checkout

View File

@@ -183,6 +183,7 @@ jobs:
- cherry-pick-to-latest-release
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout
@@ -232,6 +233,7 @@ jobs:
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout

View File

@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
with:
uv_version: "0.9.9"

View File

@@ -284,7 +284,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
@@ -626,7 +626,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}

View File

@@ -31,6 +31,7 @@ jobs:
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
environment: ci-protected
timeout-minutes: 45
env:

View File

@@ -15,6 +15,7 @@ permissions:
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd

View File

@@ -25,6 +25,7 @@ permissions:
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
@@ -54,6 +55,7 @@ jobs:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

View File

@@ -9,6 +9,7 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 45
permissions:
contents: read

View File

@@ -11,6 +11,7 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 45
steps:

View File

@@ -0,0 +1,108 @@
"""backfill_account_type
Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None
_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"
# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
# Reflect the table structure for use in DML
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("email", sa.String),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
)
def upgrade() -> None:
# ------------------------------------------------------------------
# Step 1: Backfill account_type from role.
# Order matters — most-specific matches first so the final catch-all
# only touches rows that haven't been classified yet.
# ------------------------------------------------------------------
# 1a. API key virtual users → SERVICE_ACCOUNT
op.execute(
sa.update(user_table)
.where(
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
user_table.c.account_type.is_(None),
)
.values(account_type=_SERVICE_ACCOUNT)
)
# 1b. Anonymous user → ANONYMOUS
op.execute(
sa.update(user_table)
.where(
user_table.c.id == ANONYMOUS_USER_ID,
user_table.c.account_type.is_(None),
)
.values(account_type=_ANONYMOUS)
)
# 1c. SLACK_USER role → BOT
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "SLACK_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_BOT)
)
# 1d. EXT_PERM_USER role → EXT_PERM_USER
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "EXT_PERM_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_EXT_PERM_USER)
)
# 1e. Everything else → STANDARD
op.execute(
sa.update(user_table)
.where(user_table.c.account_type.is_(None))
.values(account_type=_STANDARD)
)
# ------------------------------------------------------------------
# Step 2: Set account_type to NOT NULL now that every row is filled.
# ------------------------------------------------------------------
op.alter_column(
"user",
"account_type",
nullable=False,
server_default="STANDARD",
)
def downgrade() -> None:
op.alter_column("user", "account_type", nullable=True, server_default=None)
op.execute(sa.update(user_table).values(account_type=None))

View File

@@ -0,0 +1,104 @@
"""add_effective_permissions
Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.
Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].
Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("effective_permissions", postgresql.JSONB),
)
user_user_group = sa.table(
"user__user_group",
sa.column("user_id", sa.Uuid),
sa.column("user_group_id", sa.Integer),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"effective_permissions",
postgresql.JSONB(),
nullable=False,
server_default=sa.text("'[]'::jsonb"),
),
)
conn = op.get_bind()
# Deduplicated permissions per user
deduped = (
sa.select(
user_user_group.c.user_id,
permission_grant.c.permission,
)
.select_from(
user_user_group.join(
permission_grant,
sa.and_(
permission_grant.c.group_id == user_user_group.c.user_group_id,
permission_grant.c.is_deleted == sa.false(),
),
)
)
.distinct()
.subquery("deduped")
)
# Aggregate into JSONB array per user (order is not guaranteed;
# consumers read this as a set so ordering does not matter)
perms_per_user = (
sa.select(
deduped.c.user_id,
sa.func.jsonb_agg(
deduped.c.permission,
type_=postgresql.JSONB,
).label("perms"),
)
.group_by(deduped.c.user_id)
.subquery("sub")
)
conn.execute(
user_table.update()
.where(user_table.c.id == perms_per_user.c.user_id)
.values(effective_permissions=perms_per_user.c.perms)
)
def downgrade() -> None:
op.drop_column("user", "effective_permissions")

View File

@@ -0,0 +1,54 @@
"""csv to tabular chat file type
Revision ID: 8188861f4e92
Revises: d8cdfee5df80
Create Date: 2026-03-31 19:23:05.753184
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8188861f4e92"
down_revision = "d8cdfee5df80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'csv'
THEN jsonb_set(elem, '{type}', '"tabular"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "csv"%'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE chat_message
SET files = (
SELECT jsonb_agg(
CASE
WHEN elem->>'type' = 'tabular'
THEN jsonb_set(elem, '{type}', '"csv"')
ELSE elem
END
)
FROM jsonb_array_elements(files) AS elem
)
WHERE files::text LIKE '%"type": "tabular"%'
"""
)

View File

@@ -0,0 +1,136 @@
"""seed_default_groups
Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None
# (group_name, permission_value)
DEFAULT_GROUPS = [
("Admin", "admin"),
("Basic", "basic"),
]
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_up_to_date", sa.Boolean),
sa.column("is_up_for_deletion", sa.Boolean),
sa.column("is_default", sa.Boolean),
)
permission_grant_table = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
candidate = f"{base} {CUSTOM_SUFFIX}"
attempt = 1
while attempt <= MAX_RENAME_ATTEMPTS:
exists = conn.execute(
sa.select(sa.literal(1))
.select_from(user_group_table)
.where(user_group_table.c.name == candidate)
.limit(1)
).fetchone()
if exists is None:
return candidate
attempt += 1
candidate = f"{base} (Custom {attempt})"
raise RuntimeError(
f"Could not find an available name for group '{base}' "
f"after {MAX_RENAME_ATTEMPTS} attempts"
)
def upgrade() -> None:
conn = op.get_bind()
for group_name, permission_value in DEFAULT_GROUPS:
# Step 1: Rename ALL existing groups that clash with the canonical name.
conflicting = conn.execute(
sa.select(user_group_table.c.id, user_group_table.c.name).where(
user_group_table.c.name == group_name
)
).fetchall()
for row_id, row_name in conflicting:
new_name = _find_available_name(conn, row_name)
op.execute(
sa.update(user_group_table)
.where(user_group_table.c.id == row_id)
.values(name=new_name, is_up_to_date=False)
)
# Step 2: Create a fresh default group.
result = conn.execute(
user_group_table.insert()
.values(
name=group_name,
is_up_to_date=True,
is_up_for_deletion=False,
is_default=True,
)
.returning(user_group_table.c.id)
).fetchone()
assert result is not None
group_id = result[0]
# Step 3: Upsert permission grant.
op.execute(
pg_insert(permission_grant_table)
.values(
group_id=group_id,
permission=permission_value,
grant_source="SYSTEM",
)
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
)
def downgrade() -> None:
# Remove the default groups created by this migration.
# First remove user-group memberships that reference default groups
# to avoid FK violations, then delete the groups themselves.
default_group_ids = sa.select(user_group_table.c.id).where(
user_group_table.c.is_default == True # noqa: E712
)
op.execute(
sa.delete(user__user_group_table).where(
user__user_group_table.c.user_group_id.in_(default_group_ids)
)
)
op.execute(
sa.delete(user_group_table).where(
user_group_table.c.is_default == True # noqa: E712
)
)

View File

@@ -0,0 +1,84 @@
"""grant_basic_to_existing_groups
Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.
Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_group = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("is_default", sa.Boolean),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
conn = op.get_bind()
already_has_basic = (
sa.select(sa.literal(1))
.select_from(permission_grant)
.where(
permission_grant.c.group_id == user_group.c.id,
permission_grant.c.permission == "basic",
)
.exists()
)
groups_needing_basic = sa.select(
user_group.c.id,
sa.literal("basic").label("permission"),
sa.literal("SYSTEM").label("grant_source"),
sa.literal(False).label("is_deleted"),
).where(
user_group.c.is_default == sa.false(),
~already_has_basic,
)
conn.execute(
permission_grant.insert().from_select(
["group_id", "permission", "grant_source", "is_deleted"],
groups_needing_basic,
)
)
def downgrade() -> None:
conn = op.get_bind()
non_default_group_ids = sa.select(user_group.c.id).where(
user_group.c.is_default == sa.false()
)
conn.execute(
permission_grant.delete().where(
permission_grant.c.permission == "basic",
permission_grant.c.grant_source == "SYSTEM",
permission_grant.c.group_id.in_(non_default_group_ids),
)
)

View File

@@ -0,0 +1,116 @@
"""assign_users_to_default_groups
Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_default", sa.Boolean),
)
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
sa.column("is_active", sa.Boolean),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def upgrade() -> None:
conn = op.get_bind()
# Look up default group IDs
admin_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Admin",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
basic_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Basic",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
if admin_row is None:
raise RuntimeError(
"Default 'Admin' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
if basic_row is None:
raise RuntimeError(
"Default 'Basic' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
# Users with role=admin → Admin group
# Exclude inactive placeholder/anonymous users that are not real users
admin_users = sa.select(
sa.literal(admin_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.role == "ADMIN",
user_table.c.is_active == True, # noqa: E712
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], admin_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
# Exclude inactive placeholder/anonymous users that are not real users
basic_users = sa.select(
sa.literal(basic_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.is_active == True, # noqa: E712
sa.or_(
sa.and_(
user_table.c.account_type == "STANDARD",
user_table.c.role != "ADMIN",
),
sa.and_(
user_table.c.account_type == "SERVICE_ACCOUNT",
user_table.c.role == "BASIC",
),
),
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], basic_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
def downgrade() -> None:
# Group memberships are left in place — removing them risks
# deleting memberships that existed before this migration.
pass

View File

@@ -0,0 +1,55 @@
"""add skipped to userfilestatus
Revision ID: d8cdfee5df80
Revises: 1d78c0ca7853
Create Date: 2026-04-01 10:47:12.593950
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d8cdfee5df80"
down_revision = "1d78c0ca7853"
branch_labels = None
depends_on = None
TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = (
"PROCESSING",
"INDEXING",
"COMPLETED",
"SKIPPED",
"FAILED",
"CANCELED",
"DELETING",
)
def _drop_status_check_constraint() -> None:
inspector = sa.inspect(op.get_bind())
for constraint in inspector.get_check_constraints(TABLE):
if COLUMN in constraint.get("sqltext", ""):
constraint_name = constraint["name"]
if constraint_name is not None:
op.drop_constraint(constraint_name, TABLE, type_="check")
def upgrade() -> None:
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
def downgrade() -> None:
op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
_drop_status_check_constraint()
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")

View File

@@ -5,6 +5,7 @@ from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.hooks",
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",

View File

@@ -55,6 +55,15 @@ ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,

View File

@@ -13,6 +13,7 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -29,9 +30,10 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
@shared_task(
@@ -91,8 +93,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -103,12 +104,14 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -121,6 +124,46 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
"""
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
idempotent, tenants already at head are a fast no-op. This ensures pool
tenants are always current so that signup doesn't hit schema mismatches
(e.g. missing columns added after the tenant was pre-provisioned).
"""
with get_session_with_shared_schema() as db_session:
pool_tenants = db_session.query(AvailableTenant).all()
tenant_ids = [t.tenant_id for t in pool_tenants]
if not tenant_ids:
return
task_logger.info(
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
)
for tenant_id in tenant_ids:
try:
run_alembic_migrations(tenant_id)
new_version = get_current_alembic_version(tenant_id)
with get_session_with_shared_schema() as db_session:
tenant = (
db_session.query(AvailableTenant)
.filter_by(tenant_id=tenant_id)
.first()
)
if tenant and tenant.alembic_version != new_version:
task_logger.info(
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
)
tenant.alembic_version = new_version
db_session.commit()
except Exception:
task_logger.exception(
f"Failed to migrate pool tenant {tenant_id}, skipping"
)
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -69,5 +69,7 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
"/admin/token-rate-limits",
# Evals
"/evals",
# Hook extensions
"/admin/hooks",
}
)

View File

@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -36,6 +39,7 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
@@ -255,6 +259,7 @@ def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -269,6 +274,7 @@ def fetch_user_groups(
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
include_default: If False, excludes system default groups (is_default=True).
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -276,6 +282,8 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -286,6 +294,7 @@ def fetch_user_groups_for_user(
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -295,6 +304,8 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -478,6 +489,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
db_session.add(db_user_group)
db_session.flush() # give the group an ID
# Every group gets the "basic" permission by default
db_session.add(
PermissionGrant(
group_id=db_user_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
)
db_session.flush()
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
@@ -489,6 +510,9 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
cc_pair_ids=user_group.cc_pair_ids,
)
for uid in user_group.user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -796,6 +820,9 @@ def update_user_group(
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
for uid in set(added_user_ids) | set(removed_user_ids):
recompute_user_permissions__no_commit(uid, db_session)
db_session.commit()
return db_user_group
@@ -835,6 +862,17 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
# Collect affected user IDs before cleanup deletes the relationships
affected_user_ids = (
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == user_group_id
)
)
.scalars()
.all()
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -863,6 +901,11 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session=db_session, user_group_id=user_group_id
)
# Recompute permissions for affected users now that their
# membership in this group has been removed
for uid in affected_user_ids:
recompute_user_permissions__no_commit(uid, db_session)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()

View File

View File

@@ -0,0 +1,385 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if MULTI_TENANT:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def _execute_hook_impl(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -15,6 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.features.hooks.api import router as hook_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
@@ -138,6 +139,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
include_router_with_global_prefix_prepended(application, evals_router)
include_router_with_global_prefix_prepended(application, hook_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -52,11 +52,13 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -486,6 +488,7 @@ def create_user(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
account_type=AccountType.STANDARD,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
@@ -506,13 +509,25 @@ def create_user(
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
# Assign user to default group BEFORE commit so everything is atomic.
# If this fails, the entire user creation rolls back and IdP can retry.
try:
assign_user_to_default_groups__no_commit(db_session, user)
except Exception:
dal.rollback()
logger.exception(f"Failed to assign SCIM user {email} to default groups")
return _scim_error_response(
500, f"Failed to assign user {email} to default group"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,

View File

@@ -99,6 +99,26 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -43,12 +43,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
@router.get("/admin/user-group")
def list_user_groups(
include_default: bool = False,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
db_session,
only_up_to_date=False,
eager_load_for_snapshot=True,
include_default=include_default,
)
else:
user_groups = fetch_user_groups_for_user(
@@ -56,27 +60,50 @@ def list_user_groups(
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
include_default=include_default,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
include_default: bool = False,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
include_default=include_default,
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
include_default=include_default,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
user_group_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[str]:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
return [
grant.permission.value
for grant in group.permission_grants
if not grant.is_deleted
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
@@ -100,6 +127,9 @@ def rename_user_group_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
group = fetch_user_group(db_session, rename_request.id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
try:
return UserGroup.from_model(
rename_user_group(
@@ -185,6 +215,9 @@ def delete_user_group(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
group = fetch_user_group(db_session, user_group_id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:

View File

@@ -22,6 +22,7 @@ class UserGroup(BaseModel):
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -74,18 +75,21 @@ class UserGroup(BaseModel):
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
is_default=user_group_model.is_default,
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
is_default=user_group_model.is_default,
)

View File

@@ -100,6 +100,7 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -0,0 +1,110 @@
"""
Permission resolution for group-based authorization.
Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""
from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any
from fastapi import Depends
from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
Permission.MANAGE_AGENTS.value: {
Permission.ADD_AGENTS.value,
Permission.READ_AGENTS.value,
},
Permission.MANAGE_DOCUMENT_SETS.value: {
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_CONNECTORS.value,
},
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
Permission.MANAGE_CONNECTORS.value: {
Permission.ADD_CONNECTORS.value,
Permission.READ_CONNECTORS.value,
},
Permission.MANAGE_USER_GROUPS.value: {
Permission.READ_CONNECTORS.value,
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_AGENTS.value,
Permission.READ_USERS.value,
},
}
def resolve_effective_permissions(granted: set[str]) -> set[str]:
"""Expand granted permissions with their implied permissions.
If "admin" is present, returns all 19 permissions.
"""
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
return set(ALL_PERMISSIONS)
effective = set(granted)
changed = True
while changed:
changed = False
for perm in list(effective):
implied = IMPLIED_PERMISSIONS.get(perm)
if implied and not implied.issubset(effective):
effective |= implied
changed = True
return effective
def get_effective_permissions(user: User) -> set[Permission]:
"""Read granted permissions from the column and expand implied permissions."""
granted: set[Permission] = set()
for p in user.effective_permissions:
try:
granted.add(Permission(p))
except ValueError:
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
return set(Permission)
expanded = resolve_effective_permissions({p.value for p in granted})
return {Permission(p) for p in expanded}
def require_permission(
required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
"""FastAPI dependency factory for permission-based access control.
Usage:
@router.get("/endpoint")
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
...
"""
async def dependency(user: User = Depends(current_user)) -> User:
effective = get_effective_permissions(user)
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
return user
if required not in effective:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
return user
return dependency

View File

@@ -5,6 +5,8 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override
from onyx.db.enums import AccountType
class UserRole(str, Enum):
"""
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
account_type: AccountType = AccountType.STANDARD
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
@@ -50,12 +53,16 @@ class UserCreate(schemas.BaseUserCreate):
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
# Force STANDARD for self-registration; only trusted paths
# (SCIM, API key creation) supply a different account_type directly.
d["account_type"] = AccountType.STANDARD
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
d.setdefault("account_type", self.account_type)
return d

View File

@@ -120,11 +120,13 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
@@ -694,6 +696,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
"account_type": AccountType.STANDARD,
}
user = await self.user_db.create(user_dict)
@@ -743,14 +746,23 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
# Upgrade the user and assign default groups in a single
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
sync_user.account_type = AccountType.STANDARD
if was_inactive:
sync_user.is_active = True
assign_user_to_default_groups__no_commit(sync_db, sync_user)
sync_db.commit()
# Refresh the async user object so downstream code
# (e.g. oidc_expiry check) sees the updated fields.
user = await self.user_db.get(user.id) # type: ignore[arg-type]
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -836,6 +848,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
event=MilestoneRecordType.TENANT_CREATED,
)
# Assign user to the appropriate default group (Admin or Basic).
# Must happen inside the try block while tenant context is active,
# otherwise get_session_with_current_tenant() targets the wrong schema.
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
with get_session_with_current_tenant() as db_session:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=is_admin
)
db_session.commit()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1554,6 +1576,7 @@ def get_anonymous_user() -> User:
is_verified=True,
is_superuser=False,
role=UserRole.LIMITED,
account_type=AccountType.ANONYMOUS,
use_memories=False,
enable_memory_tool=False,
)

View File

@@ -20,6 +20,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -65,6 +66,7 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -515,7 +517,8 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, **kwargs: Any # noqa: ARG001
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -317,7 +317,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,7 +14,6 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -362,19 +361,6 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -9,6 +9,7 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -137,6 +138,7 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -319,6 +319,11 @@ def monitor_indexing_attempt_progress(
)
current_db_time = get_db_current_time(db_session)
total_batches: int | str = (
coordination_status.total_batches
if coordination_status.total_batches is not None
else "?"
)
if coordination_status.found:
task_logger.info(
f"Indexing attempt progress: "
@@ -326,7 +331,7 @@ def monitor_indexing_attempt_progress(
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"completed_batches={coordination_status.completed_batches} "
f"total_batches={coordination_status.total_batches or '?'} "
f"total_batches={total_batches} "
f"total_docs={coordination_status.total_docs} "
f"total_failures={coordination_status.total_failures}"
f"elapsed={(current_db_time - attempt.time_created).seconds}"
@@ -410,7 +415,7 @@ def check_indexing_completion(
logger.info(
f"Indexing status: "
f"indexing_completed={indexing_completed} "
f"batches_processed={batches_processed}/{batches_total or '?'} "
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_chunks={coordination_status.total_chunks} "
f"total_failures={coordination_status.total_failures}"

View File

@@ -805,6 +805,10 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
@@ -1075,7 +1079,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

View File

@@ -212,6 +212,7 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -277,6 +278,7 @@ class NotificationType(str, Enum):
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
USER_GROUP_ASSIGNMENT_FAILED = "user_group_assignment_failed"
class BlobType(str, Enum):
@@ -672,6 +674,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -0,0 +1,32 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
canvas_client: CanvasApiClient,
course_id: int,
) -> ExternalAccess | None:
if not global_version.is_ee_version():
return None
ee_get_course_permissions = cast(
Callable[[CanvasApiClient, int], ExternalAccess | None],
fetch_versioned_implementation(
"onyx.external_permissions.canvas.access",
"get_course_permissions",
),
)
return ee_get_course_permissions(canvas_client, course_id)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
@@ -190,3 +191,22 @@ class CanvasApiClient:
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,17 +1,82 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
"""Map Canvas API errors to connector framework exceptions."""
if e.status_code == 401:
raise CredentialExpiredError(
"Canvas API token is invalid or expired (HTTP 401)."
)
elif e.status_code == 403:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
)
else:
raise ConnectorValidationError(
f"Canvas API error (status={e.status_code}): {e}"
)
class CanvasCourse(BaseModel):
id: int
name: str
course_code: str
created_at: str
workflow_state: str
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
class CanvasPage(BaseModel):
@@ -19,10 +84,22 @@ class CanvasPage(BaseModel):
url: str
title: str
body: str | None = None
created_at: str
updated_at: str
created_at: str | None = None
updated_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
@@ -30,10 +107,23 @@ class CanvasAssignment(BaseModel):
description: str | None = None
html_url: str
course_id: int
created_at: str
updated_at: str
created_at: str | None = None
updated_at: str | None = None
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
@@ -43,6 +133,17 @@ class CanvasAnnouncement(BaseModel):
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
@@ -72,3 +173,286 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
try:
self.canvas_client.get("courses", params={"per_page": "1"})
logger.info("Canvas connector settings validated successfully")
except OnyxError as e:
_handle_canvas_api_error(e)
except ConnectorMissingCredentialError:
raise
except Exception as exc:
raise UnexpectedValidationError(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
# TODO(benwu408): implemented in PR4 (perm sync)
raise NotImplementedError

View File

@@ -11,11 +11,13 @@ from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.errors import LoginFailure
from discord.flags import Intents
from discord.message import Message as DiscordMessage
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -209,8 +211,19 @@ def _manage_async_retrieval(
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents) as discord_client:
asyncio.create_task(discord_client.start(token))
await discord_client.wait_until_ready()
start_task = asyncio.create_task(discord_client.start(token))
ready_task = asyncio.create_task(discord_client.wait_until_ready())
done, _ = await asyncio.wait(
{start_task, ready_task},
return_when=asyncio.FIRST_COMPLETED,
)
# start() runs indefinitely once connected, so it only lands
# in `done` when login/connection failed — propagate the error.
if start_task in done:
ready_task.cancel()
start_task.result()
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=discord_client,
@@ -276,6 +289,19 @@ class DiscordConnector(PollConnector, LoadConnector):
self._discord_bot_token = credentials["discord_bot_token"]
return None
def validate_connector_settings(self) -> None:
loop = asyncio.new_event_loop()
try:
client = Client(intents=Intents.default())
try:
loop.run_until_complete(client.login(self.discord_bot_token))
except LoginFailure as e:
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
finally:
loop.run_until_complete(client.close())
finally:
loop.close()
def _manage_doc_batching(
self,
start: datetime | None = None,

View File

@@ -72,6 +72,10 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -11,14 +11,19 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -87,6 +92,7 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)
@@ -99,7 +105,21 @@ def insert_api_key(
)
db_session.add(api_key_row)
# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# LIMITED role service accounts should have no group membership.
# Late import to avoid circular dependency (api_key <- users <- api_key).
if api_key_args.role != UserRole.LIMITED:
from onyx.db.users import assign_user_to_default_groups__no_commit
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()
chat_sessions = list(result.scalars().all())
return list(chat_sessions)
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
def delete_orphaned_search_docs(db_session: Session) -> None:

View File

@@ -13,19 +13,19 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "user"
SCIM = "scim"
SYSTEM = "system"
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"
class IndexingStatus(str, PyEnum):
@@ -215,6 +215,7 @@ class UserFileStatus(str, PyEnum):
PROCESSING = "PROCESSING"
INDEXING = "INDEXING"
COMPLETED = "COMPLETED"
SKIPPED = "SKIPPED"
FAILED = "FAILED"
CANCELED = "CANCELED"
DELETING = "DELETING"

View File

@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
)
"""
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)
effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(Permission, native_enum=False), nullable=False
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False

View File

@@ -3,6 +3,7 @@ from datetime import timezone
from uuid import UUID
from sqlalchemy import cast
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
@@ -90,9 +91,18 @@ def get_notifications(
notif_type: NotificationType | None = None,
include_dismissed: bool = True,
) -> list[Notification]:
query = select(Notification).where(
Notification.user_id == user.id if user else Notification.user_id.is_(None)
)
if user is None:
user_filter = Notification.user_id.is_(None)
elif user.role == UserRole.ADMIN:
# Admins see their own notifications AND admin-targeted ones (user_id IS NULL)
user_filter = or_(
Notification.user_id == user.id,
Notification.user_id.is_(None),
)
else:
user_filter = Notification.user_id == user.id
query = select(Notification).where(user_filter)
if not include_dismissed:
query = query.where(Notification.dismissed.is_(False))
if notif_type:

View File

@@ -0,0 +1,97 @@
"""
DB operations for recomputing user effective_permissions.
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""
from collections import defaultdict
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup
def recompute_user_permissions__no_commit(user_id: UUID, db_session: Session) -> None:
"""Recompute a single user's granted permissions from their group grants.
Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().
Does NOT commit — caller must commit the session.
"""
stmt = (
select(PermissionGrant.permission)
.join(
User__UserGroup,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id == user_id,
PermissionGrant.is_deleted.is_(False),
)
)
rows = db_session.execute(stmt).scalars().all()
# sorted for consistent ordering in DB — easier to read when debugging
granted = sorted({p.value for p in rows})
db_session.execute(
update(User).where(User.id == user_id).values(effective_permissions=granted)
)
def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.
Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = list(
db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
)
if not user_ids:
return
# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(user_ids),
PermissionGrant.is_deleted.is_(False),
)
).all()
# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID, set[str]] = defaultdict(set)
for uid in user_ids:
perms_by_user[uid] # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)
for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid)
.values(effective_permissions=sorted(perms))
)

View File

@@ -7,6 +7,7 @@ from fastapi import HTTPException
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
@@ -17,6 +18,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import UserFileStatus
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
@@ -34,9 +36,19 @@ class CategorizedFilesResult(BaseModel):
user_files: list[UserFile]
rejected_files: list[RejectedFile]
id_to_temp_id: dict[str, str]
# Filenames that should be stored but not indexed.
skip_indexing_filenames: set[str] = Field(default_factory=set)
# Allow SQLAlchemy ORM models inside this result container
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def indexable_files(self) -> list[UserFile]:
return [
uf
for uf in self.user_files
if (uf.name or "") not in self.skip_indexing_filenames
]
def build_hashed_file_key(file: UploadFile) -> str:
name_prefix = (file.filename or "")[:50]
@@ -70,6 +82,7 @@ def create_user_files(
)
if new_temp_id is not None:
id_to_temp_id[str(new_id)] = new_temp_id
should_skip = (file.filename or "") in categorized_files.skip_indexing
new_file = UserFile(
id=new_id,
user_id=user.id,
@@ -81,6 +94,7 @@ def create_user_files(
link_url=link_url,
content_type=file.content_type,
file_type=file.content_type,
status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
)
# Persist the UserFile first to satisfy FK constraints for association table
@@ -98,6 +112,7 @@ def create_user_files(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files.skip_indexing,
)
@@ -123,6 +138,7 @@ def upload_files_to_user_files_with_indexing(
user_files = categorized_files_result.user_files
rejected_files = categorized_files_result.rejected_files
id_to_temp_id = categorized_files_result.id_to_temp_id
indexable_files = categorized_files_result.indexable_files
# Trigger per-file processing immediately for the current tenant
tenant_id = get_current_tenant_id()
for rejected_file in rejected_files:
@@ -134,12 +150,12 @@ def upload_files_to_user_files_with_indexing(
from onyx.background.task_utils import drain_processing_loop
background_tasks.add_task(drain_processing_loop, tenant_id)
for user_file in user_files:
for user_file in indexable_files:
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
for user_file in user_files:
for user_file in indexable_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
@@ -155,6 +171,7 @@ def upload_files_to_user_files_with_indexing(
user_files=user_files,
rejected_files=rejected_files,
id_to_temp_id=id_to_temp_id,
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
)

View File

@@ -19,6 +19,7 @@ from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.enums import AccountType
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -27,8 +28,11 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
@@ -298,6 +302,7 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)
@@ -308,6 +313,7 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user
@@ -344,6 +350,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
@@ -375,6 +382,81 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users
def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.
Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.
Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return
target_group_name = "Admin" if is_admin else "Basic"
default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)
if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)
# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return
savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return
from onyx.db.permissions import recompute_user_permissions__no_commit
recompute_user_permissions__no_commit(user.id, db_session)
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
@@ -421,13 +503,14 @@ def delete_user_from_db(
def batch_get_user_groups(
db_session: Session,
user_ids: list[UUID],
include_default: bool = False,
) -> dict[UUID, list[tuple[int, str]]]:
"""Fetch group memberships for a batch of users in a single query.
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
if not user_ids:
return {}
rows = db_session.execute(
stmt = (
select(
User__UserGroup.user_id,
UserGroup.id,
@@ -435,7 +518,11 @@ def batch_get_user_groups(
)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.where(User__UserGroup.user_id.in_(user_ids))
).all()
)
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
rows = db_session.execute(stmt).all()
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
for user_id, group_id, group_name in rows:

View File

@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
) -> list[str]:
"""Searches the index and returns only document chunk IDs.

View File

@@ -60,8 +60,7 @@ class OpenSearchSearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
RANDOM = "random"
ID_RETRIEVAL = "id_retrieval"
DOCUMENT_IDS = "document_ids"
DOC_ID_RETRIEVAL = "doc_id_retrieval"
UNKNOWN = "unknown"

View File

@@ -6,6 +6,7 @@ import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -738,6 +739,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
@@ -924,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -10,6 +10,7 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -427,7 +428,9 @@ class VespaDocumentIndex(DocumentIndex):
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,

View File

@@ -44,6 +44,7 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]

View File

@@ -15,6 +15,7 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
class OnyxMimeTypes:
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
CSV_MIME_TYPES = {"text/csv"}
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
TEXT_MIME_TYPES = {
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
@@ -34,13 +35,12 @@ class OnyxMimeTypes:
PDF_MIME_TYPE,
WORD_PROCESSING_MIME_TYPE,
PRESENTATION_MIME_TYPE,
SPREADSHEET_MIME_TYPE,
"message/rfc822",
"application/epub+zip",
}
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
)
EXCLUDED_IMAGE_TYPES = {
@@ -53,6 +53,11 @@ class OnyxMimeTypes:
class OnyxFileExtensions:
TABULAR_EXTENSIONS = {
".csv",
".tsv",
".xlsx",
}
PLAIN_TEXT_EXTENSIONS = {
".txt",
".md",

View File

@@ -13,13 +13,14 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
CSV = "csv"
# Tabular data files (CSV, XLSX)
TABULAR = "tabular"
def is_text_file(self) -> bool:
return self in (
ChatFileType.PLAIN_TEXT,
ChatFileType.DOC,
ChatFileType.CSV,
ChatFileType.TABULAR,
)

View File

@@ -1,4 +1,3 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
@@ -7,10 +6,7 @@ from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Hooks are only available in single-tenant / self-hosted EE deployments.
Use as: Depends(require_hook_enabled)
"""
@@ -19,8 +15,3 @@ def require_hook_enabled() -> None:
OnyxErrorCode.SINGLE_TENANT_ONLY,
"Hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.ENV_VAR_GATED,
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -1,79 +1,22 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
"""CE hook executor.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
HookSkipped and HookSoftFailed are real classes kept here because
process_message.py (CE code) uses isinstance checks against them.
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
execute_hook is the public entry point. It dispatches to _execute_hook_impl
via fetch_versioned_implementation so that:
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
from onyx.utils.variable_functionality import fetch_versioned_implementation
class HookSkipped:
@@ -87,277 +30,15 @@ class HookSoftFailed:
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
def _execute_hook_impl(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
db_session: Session, # noqa: ARG001
hook_point: HookPoint, # noqa: ARG001
payload: dict[str, Any], # noqa: ARG001
response_type: type[T], # noqa: ARG001
) -> T | HookSkipped | HookSoftFailed:
"""CE no-op — hooks are not available without EE."""
return HookSkipped()
def execute_hook(
@@ -367,25 +48,15 @@ def execute_hook(
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
"""Execute the hook for the given hook point.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
Dispatches to the versioned implementation so EE gets the real executor
and CE gets the no-op stub, without any changes at the call site.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
return impl(
db_session=db_session,
hook_point=hook_point,
payload=payload,
response_type=response_type,
)

View File

@@ -1,5 +0,0 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -19,7 +19,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -85,14 +86,21 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
tenant_id: str,
chunks: list[DocAwareChunk],
) -> "DocumentChunkEnricher":
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
no_access = DocumentAccess.build(
user_emails=[],
@@ -102,67 +110,30 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
def _get_ancestor_ids_for_documents(
@@ -203,7 +174,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -227,7 +198,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -249,3 +220,52 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -24,7 +27,8 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -102,13 +106,20 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
no_access = DocumentAccess.build(
user_emails=[],
@@ -118,7 +129,6 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -139,17 +149,6 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -164,15 +163,9 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
@@ -184,28 +177,16 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -249,8 +230,9 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -266,8 +248,10 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
str(user_file.id)
]
@@ -279,8 +263,54 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -0,0 +1,89 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from onyx.indexing.models import IndexChunk
class ChunkBatchStore:
"""Manages serialization of embedded chunks to a temporary directory.
Owns the temp directory lifetime and provides save/load/stream/scrub
operations.
Use as a context manager to ensure cleanup::
with ChunkBatchStore() as store:
store.save(chunks, batch_idx=0)
for chunk in store.stream():
...
"""
_EXT = ".pkl"
def __init__(self) -> None:
self._tmpdir: Path | None = None
# -- context manager -----------------------------------------------------
def __enter__(self) -> "ChunkBatchStore":
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
return self
def __exit__(self, *_exc: object) -> None:
if self._tmpdir is not None:
shutil.rmtree(self._tmpdir, ignore_errors=True)
self._tmpdir = None
@property
def _dir(self) -> Path:
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
return self._tmpdir
# -- storage primitives --------------------------------------------------
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
"""Serialize a batch of embedded chunks to disk."""
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
pickle.dump(chunks, f)
def _load(self, batch_file: Path) -> list[IndexChunk]:
"""Deserialize a batch of embedded chunks from a file."""
with open(batch_file, "rb") as f:
return pickle.load(f)
def _batch_files(self) -> list[Path]:
"""Return batch files sorted by numeric index."""
return sorted(
self._dir.glob(f"batch_*{self._EXT}"),
key=lambda p: int(p.stem.removeprefix("batch_")),
)
# -- higher-level operations ---------------------------------------------
def stream(self) -> Iterator[IndexChunk]:
"""Yield all chunks across all batch files.
Each call returns a fresh generator, so the data can be iterated
multiple times (e.g. once per document index).
"""
for batch_file in self._batch_files():
yield from self._load(batch_file)
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
When a document fails embedding in batch N, earlier batches may
already contain successfully embedded chunks for that document.
This ensures the output is all-or-nothing per document.
"""
for batch_file in self._batch_files():
batch_chunks = self._load(batch_file)
cleaned = [
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
]
if len(cleaned) != len(batch_chunks):
with open(batch_file, "wb") as f:
pickle.dump(cleaned, f)

View File

@@ -1,5 +1,8 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from pydantic import BaseModel
@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class ChunkEmbeddingResult(BaseModel):
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
connector_failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
)
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
"""Extract document IDs from a list of connector failures."""
return {f.failed_document.document_id for f in failures if f.failed_document}
def _embed_chunks_to_store(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
"""Embed chunks in batches, spilling each batch to *store*.
If a document fails embedding in any batch, its chunks are excluded from
all batches (including earlier ones already written) so that the output
is all-or-nothing per document.
"""
successful_chunk_ids: list[tuple[int, str]] = []
all_embedding_failures: list[ConnectorFailure] = []
# Track failed doc IDs across all batches so that a failure in batch N
# causes chunks for that doc to be skipped in batch N+1 and stripped
# from earlier batches.
all_failed_doc_ids: set[str] = set()
for batch_idx, chunk_batch in enumerate(
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
):
# Skip chunks belonging to documents that failed in earlier batches.
chunk_batch = [
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
]
if not chunk_batch:
continue
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
chunks=chunk_batch,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
all_embedding_failures.extend(embedding_failures)
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
# Only keep successfully embedded chunks for non-failed docs.
chunks_with_embeddings = [
c
for c in chunks_with_embeddings
if c.source_document.id not in all_failed_doc_ids
]
successful_chunk_ids.extend(
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
)
store.save(chunks_with_embeddings, batch_idx)
del chunks_with_embeddings
# Scrub earlier batches for docs that failed in later batches.
if all_failed_doc_ids:
store.scrub_failed_docs(all_failed_doc_ids)
successful_chunk_ids = [
(chunk_id, doc_id)
for chunk_id, doc_id in successful_chunk_ids
if doc_id not in all_failed_doc_ids
]
return ChunkEmbeddingResult(
successful_chunk_ids=successful_chunk_ids,
connector_failures=all_embedding_failures,
)
@contextmanager
def embed_and_stream(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
"""Embed chunks to disk and yield a ``(result, store)`` pair.
The store owns the temp directory — files are cleaned up when the context
manager exits.
Usage::
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
for chunk in store.stream():
...
"""
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
store=store,
)
yield result, store
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -637,6 +762,29 @@ def add_contextual_summaries(
return chunks
def _verify_indexing_completeness(
insertion_records: list[DocumentInsertionRecord],
write_failures: list[ConnectorFailure],
embedding_failed_doc_ids: set[str],
updatable_ids: list[str],
document_index_name: str,
) -> None:
"""Verify that every updatable document was either indexed or reported as failed."""
all_returned_doc_ids = (
{r.document_id for r in insertion_records}
| {f.failed_document.document_id for f in write_failures if f.failed_document}
| embedding_failed_doc_ids
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
f"This should never happen. "
f"This occured for document index {document_index_name}"
)
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -672,12 +820,7 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
return IndexingPipelineResult.empty(len(filtered_documents))
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -716,117 +859,99 @@ def index_doc_batch(
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
if chunks
else ([], [])
)
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
embedding_result,
chunk_store,
):
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk_id,
document_id=document_id,
boost_score=1.0,
)
for chunk_id, document_id in embedding_result.successful_chunk_ids
]
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with adapter.lock_context(context.updatable_docs):
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
context=context,
embedding_failed_doc_ids = _get_failed_doc_ids(
embedding_result.connector_failures
)
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# Filter to only successfully embedded chunks so
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
embedded_chunks = [
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
]
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
# Acquires a lock on the documents so that no other process can modify
# them. Not needed until here, since this is when the actual race
# condition with vector db can occur.
with adapter.lock_context(context.updatable_docs):
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=embedded_chunks,
)
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
None
)
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
None
)
for document_index in document_indices:
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
for chunk in chunk_store.stream():
yield enricher.enrich_chunk(chunk, 1.0)
insertion_records, write_failures = (
write_chunks_to_vector_db_with_backoff(
document_index=document_index,
make_chunks=_enriched_stream,
index_batch_params=index_batch_params,
)
)
_verify_indexing_completeness(
insertion_records=insertion_records,
write_failures=write_failures,
embedding_failed_doc_ids=embedding_failed_doc_ids,
updatable_ids=updatable_ids,
document_index_name=document_index.__class__.__name__,
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
new_docs=sum(
1 for r in primary_doc_idx_insertion_records if not r.already_existed
),
total_docs=len(filtered_documents),
total_chunks=len(chunks_with_embeddings),
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
total_chunks=len(embedding_result.successful_chunk_ids),
failures=primary_doc_idx_vector_db_write_failures
+ embedding_result.connector_failures,
)

View File

@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
class IndexingBatchAdapter(Protocol):
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def build_metadata_aware_chunks(
def prepare_enrichment(
self,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
result: BuildMetadataAwareChunksResult,
enrichment: ChunkEnrichmentContext,
) -> None: ...

View File

@@ -1,6 +1,9 @@
import time
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present.
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous batches
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
chunks=make_chunks(),
index_batch_params=index_batch_params,
)
),
@@ -60,14 +63,23 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
seen_doc_ids: set[str] = set()
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
if doc_id in seen_doc_ids:
raise RuntimeError(
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
)
seen_doc_ids.add(doc_id)
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
try:
insertion_records.extend(
document_index.index(
@@ -87,9 +99,7 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
document_link=first_chunk.get_link(),
),
failure_message=str(e),
exception=e,

View File

@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
reasoning_effort
)
if budget_tokens is not None:
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024

View File

@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -439,6 +438,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -454,7 +454,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -76,11 +76,18 @@ class CategorizedFiles(BaseModel):
acceptable: list[UploadFile] = Field(default_factory=list)
rejected: list[RejectedFile] = Field(default_factory=list)
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
# Filenames within `acceptable` that should be stored but not indexed.
skip_indexing: set[str] = Field(default_factory=set)
# Allow FastAPI UploadFile instances
model_config = ConfigDict(arbitrary_types_allowed=True)
def _skip_token_threshold(extension: str) -> bool:
"""Return True if this file extension should bypass the token limit."""
return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
if max(width, height) <= cap:
return width, height
@@ -264,7 +271,17 @@ def categorize_uploaded_files(
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
exceeds_threshold = (
token_threshold is not None and token_count > token_threshold
)
if exceeds_threshold and _skip_token_threshold(extension):
# Exempt extensions (e.g. spreadsheets) are accepted
# but flagged to skip indexing — only metadata is
# injected into the LLM context.
results.acceptable.append(upload)
results.acceptable_file_to_token_count[filename] = token_count
results.skip_indexing.add(filename)
elif exceeds_threshold:
results.rejected.append(
RejectedFile(
filename=filename,

View File

@@ -27,6 +27,7 @@ from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.schemas import UserRole
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
@@ -773,6 +774,13 @@ def _get_token_created_at(
return get_current_token_creation_postgres(user, db_session)
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
def get_current_user_permissions(
user: User = Depends(current_user),
) -> list[str]:
return sorted(p.value for p in get_effective_permissions(user))
@router.get("/me", tags=PUBLIC_API_TAGS)
def verify_user_logged_in(
request: Request,

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.models import User
@@ -41,6 +42,7 @@ class FullUserSnapshot(BaseModel):
id: UUID
email: str
role: UserRole
account_type: AccountType
is_active: bool
password_configured: bool
personal_name: str | None
@@ -60,6 +62,7 @@ class FullUserSnapshot(BaseModel):
id=user.id,
email=user.email,
role=user.role,
account_type=user.account_type,
is_active=user.is_active,
password_configured=user.password_configured,
personal_name=user.personal_name,

View File

@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
return ChatFileType.IMAGE
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
return ChatFileType.CSV
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
return ChatFileType.TABULAR
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
return ChatFileType.DOC

View File

@@ -21,7 +21,6 @@ from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -38,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -98,7 +98,7 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
hooks_enabled=not MULTI_TENANT,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(

View File

@@ -116,7 +116,7 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
# True when hooks are available: single-tenant EE deployments only.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -1,3 +1,4 @@
import io
import json
from typing import Any
from typing import cast
@@ -9,6 +10,7 @@ from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_chat_file_by_id
@@ -169,10 +171,13 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
chat_file = self._load_file(file_id)
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
# DOC type in a loaded file means plaintext extraction failed and the
# content is the original binary (e.g. raw PDF/DOCX bytes).
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
if chat_file.file_type not in (
ChatFileType.PLAIN_TEXT,
ChatFileType.TABULAR,
):
raise ToolCallException(
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
llm_facing_message=(
@@ -181,7 +186,19 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
)
try:
full_text = chat_file.content.decode("utf-8", errors="replace")
if chat_file.file_type == ChatFileType.PLAIN_TEXT:
full_text = chat_file.content.decode("utf-8", errors="replace")
else:
full_text = (
extract_file_text(
file=io.BytesIO(chat_file.content),
file_name=chat_file.filename or "",
break_on_unprocessable=False,
)
or ""
)
except ToolCallException:
raise
except Exception:
raise ToolCallException(
message=f"Failed to decode file {file_id}",

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.5
cryptography==46.0.6
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.11
langchain-core==1.2.22
# via onyx
langdetect==1.0.9
# via unstructured
@@ -735,7 +735,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
# via onyx
pygments==2.19.2
pygments==2.20.0
# via rich
pyjwt==2.12.0
# via

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.19.2
pygments==2.20.0
# via
# ipython
# ipython-pygments-lexers

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.5
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -5,6 +5,7 @@ import asyncio
import json
import logging
import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
@@ -27,6 +28,9 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
QUESTION_TIMEOUT_SECONDS = 300
QUESTION_RETRY_PAUSE_SECONDS = 30
MAX_QUESTION_ATTEMPTS = 3
@dataclass(frozen=True)
@@ -109,6 +113,27 @@ def normalize_api_base(api_base: str) -> str:
return f"{normalized}/api"
def load_completed_question_ids(output_file: Path) -> set[str]:
if not output_file.exists():
return set()
completed_ids: set[str] = set()
with output_file.open("r", encoding="utf-8") as file:
for line in file:
stripped = line.strip()
if not stripped:
continue
try:
record = json.loads(stripped)
except json.JSONDecodeError:
continue
question_id = record.get("question_id")
if isinstance(question_id, str) and question_id:
completed_ids.add(question_id)
return completed_ids
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
@@ -348,6 +373,7 @@ async def generate_answers(
api_base: str,
api_key: str,
parallelism: int,
skipped: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
@@ -382,58 +408,178 @@ async def generate_answers(
write_lock = asyncio.Lock()
completed = 0
successful = 0
stuck_count = 0
failed_questions: list[FailedQuestionRecord] = []
total = len(questions)
remaining_count = len(questions)
overall_total = remaining_count + skipped
question_durations: list[float] = []
run_start_time = time.monotonic()
def print_progress() -> None:
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
elapsed = time.monotonic() - run_start_time
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
done = skipped + completed
bar_width = 30
filled = (
int(bar_width * done / overall_total)
if overall_total
else bar_width
)
bar = "" * filled + "" * (bar_width - filled)
pct = (done / overall_total * 100) if overall_total else 100.0
parts = (
f"\r{bar} {pct:5.1f}% "
f"[{done}/{overall_total}] "
f"avg {avg_time:.1f}s/q "
f"elapsed {elapsed:.0f}s "
f"ETA {eta:.0f}s "
f"(ok:{successful} fail:{len(failed_questions)}"
)
if stuck_count:
parts += f" stuck:{stuck_count}"
if skipped:
parts += f" skip:{skipped}"
parts += ")"
sys.stderr.write(parts)
sys.stderr.flush()
print_progress()
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
nonlocal stuck_count
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
last_error: Exception | None = None
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
q_start = time.monotonic()
try:
async with semaphore:
result = await asyncio.wait_for(
submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
),
timeout=QUESTION_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
async with progress_lock:
stuck_count += 1
logger.warning(
"Question %s timed out after %ss (attempt %s/%s, "
"total stuck: %s) — retrying in %ss",
question_record.question_id,
QUESTION_TIMEOUT_SECONDS,
attempt,
MAX_QUESTION_ATTEMPTS,
stuck_count,
QUESTION_RETRY_PAUSE_SECONDS,
)
print_progress()
last_error = TimeoutError(
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
)
except Exception as exc:
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
continue
except Exception as exc:
duration = time.monotonic() - q_start
async with progress_lock:
completed += 1
question_durations.append(duration)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
remaining_count,
)
print_progress()
return
duration = time.monotonic() - q_start
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
successful += 1
question_durations.append(duration)
print_progress()
return
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
# All attempts exhausted due to timeouts
async with progress_lock:
completed += 1
successful += 1
logger.info("Processed %s/%s questions", completed, total)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(last_error),
)
)
logger.error(
"Question %s failed after %s timeout attempts (%s/%s)",
question_record.question_id,
MAX_QUESTION_ATTEMPTS,
completed,
remaining_count,
)
print_progress()
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
# Final newline after progress bar
sys.stderr.write("\n")
sys.stderr.flush()
total_elapsed = time.monotonic() - run_start_time
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
resume_suffix = (
f"{skipped} previously completed, "
f"{skipped + successful}/{overall_total} overall"
if skipped
else ""
)
logger.info(
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
successful,
remaining_count,
total_elapsed,
avg_time,
stuck_suffix,
resume_suffix,
)
if failed_questions:
logger.warning(
"Completed with %s failed questions and %s successful questions.",
"%s questions failed:",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
@@ -453,7 +599,30 @@ def main() -> None:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
completed_ids = load_completed_question_ids(args.output_file)
logger.info(
"Found %s already-answered question IDs in %s",
len(completed_ids),
args.output_file,
)
total_before_filter = len(questions)
questions = [q for q in questions if q.question_id not in completed_ids]
skipped = total_before_filter - len(questions)
if skipped:
logger.info(
"Resuming: %s/%s already answered, %s remaining",
skipped,
total_before_filter,
len(questions),
)
else:
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
if not questions:
logger.info("All questions already answered. Nothing to do.")
return
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
@@ -463,6 +632,7 @@ def main() -> None:
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
skipped=skipped,
)
)

View File

@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
class TestAdapterWritesBothMetadataFields:
"""build_metadata_aware_chunks must populate user_project AND personas."""
"""prepare_enrichment must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_store.file_store import get_default_file_store
@@ -52,7 +53,12 @@ def tenant_context() -> Generator[None, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def create_test_user(db_session: Session, email_prefix: str) -> User:
def create_test_user(
db_session: Session,
email_prefix: str,
role: UserRole = UserRole.BASIC,
account_type: AccountType = AccountType.STANDARD,
) -> User:
"""Helper to create a test user with a unique email"""
# Use UUID to ensure unique email addresses
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
@@ -68,7 +74,8 @@ def create_test_user(db_session: Session, email_prefix: str) -> User:
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
role=role,
account_type=account_type,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,16 +13,29 @@ from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import AccountType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.models import UserRole
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_ext_perm_user(db_session: Session, name: str) -> User:
"""Create an external-permission user for group sync tests."""
return create_test_user(
db_session,
name,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
def _create_test_connector_credential_pair(
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
) -> ConnectorCredentialPair:
@@ -100,9 +113,9 @@ class TestPerformExternalGroupSync:
def test_initial_group_sync(self, db_session: Session) -> None:
"""Test syncing external groups for the first time (initial sync)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Mock external groups data as a generator that yields the expected groups
@@ -175,9 +188,9 @@ class TestPerformExternalGroupSync:
def test_update_existing_groups(self, db_session: Session) -> None:
"""Test updating existing groups (adding/removing users)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with original groups
@@ -272,8 +285,8 @@ class TestPerformExternalGroupSync:
def test_remove_groups(self, db_session: Session) -> None:
"""Test removing groups (groups that no longer exist in external system)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with multiple groups
@@ -357,7 +370,7 @@ class TestPerformExternalGroupSync:
def test_empty_group_sync(self, db_session: Session) -> None:
"""Test syncing when no groups are returned (all groups removed)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user1 = _create_ext_perm_user(db_session, "user1")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with groups
@@ -413,7 +426,7 @@ class TestPerformExternalGroupSync:
# Create many test users
users = []
for i in range(150): # More than the batch size of 100
users.append(create_test_user(db_session, f"user{i}"))
users.append(_create_ext_perm_user(db_session, f"user{i}"))
cc_pair = _create_test_connector_credential_pair(db_session)
@@ -452,8 +465,8 @@ class TestPerformExternalGroupSync:
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
"""Test syncing a mix of regular and public groups"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
def mixed_group_sync_func(

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.enums import BuildSessionStatus
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -52,6 +53,7 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -0,0 +1,51 @@
"""
Tests that account_type is correctly set when creating users through
the internal DB functions: add_slack_user_if_not_exists and
batch_add_ext_perm_user_if_not_exists.
These functions are called by background workers (Slack bot, permission sync)
and are not exposed via API endpoints, so they must be tested directly.
"""
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.models import UserRole
from onyx.db.users import add_slack_user_if_not_exists
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
users = batch_add_ext_perm_user_if_not_exists(
db_session, ["extperm_acct_type@test.com"]
)
assert len(users) == 1
user = users[0]
assert user.role == UserRole.EXT_PERM_USER
assert user.account_type == AccountType.EXT_PERM_USER
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
db_session: Session,
) -> None:
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
email = "ext_to_slack_acct_type@test.com"
# Create as ext_perm user first
batch_add_ext_perm_user_if_not_exists(db_session, [email])
# Now "upgrade" via slack path
user = add_slack_user_if_not_exists(db_session, email)
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT

View File

@@ -8,6 +8,7 @@ import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -46,6 +47,7 @@ def _create_admin(db_session: Session) -> User:
is_superuser=True,
is_verified=True,
role=UserRole.ADMIN,
account_type=AccountType.STANDARD,
)
db_session.add(user)
db_session.commit()

View File

@@ -1175,7 +1175,7 @@ def test_code_interpreter_receives_chat_files(
file_descriptor: FileDescriptor = {
"id": user_file.file_id,
"type": ChatFileType.CSV,
"type": ChatFileType.TABULAR,
"name": "data.csv",
"user_file_id": str(user_file.id),
}

View File

@@ -126,6 +126,15 @@ class UserManager:
return test_user
@staticmethod
def get_permissions(user: DATestUser) -> list[str]:
response = requests.get(
url=f"{API_SERVER_URL}/me/permissions",
headers=user.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def is_role(
user_to_verify: DATestUser,

View File

@@ -104,13 +104,30 @@ class UserGroupManager:
)
response.raise_for_status()
@staticmethod
def get_permissions(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
) -> list[str]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all(
user_performing_action: DATestUser,
include_default: bool = False,
) -> list[UserGroup]:
params: dict[str, str] = {}
if include_default:
params["include_default"] = "true"
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers,
params=params,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]

View File

@@ -1,9 +1,13 @@
from uuid import UUID
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
@@ -33,3 +37,120 @@ def test_limited(reset: None) -> None: # noqa: ARG001
headers=api_key.headers,
)
assert response.status_code == 403
def _get_service_account_account_type(
admin_user: DATestUser,
api_key_user_id: UUID,
) -> AccountType:
"""Fetch the account_type of a service account user via the user listing API."""
response = requests.get(
f"{API_SERVER_URL}/manage/users",
headers=admin_user.headers,
params={"include_api_keys": "true"},
)
response.raise_for_status()
data = response.json()
user_id_str = str(api_key_user_id)
for user in data["accepted"]:
if user["id"] == user_id_str:
return AccountType(user["account_type"])
raise AssertionError(
f"Service account user {user_id_str} not found in user listing"
)
def _get_default_group_user_ids(
admin_user: DATestUser,
) -> tuple[set[str], set[str]]:
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
admin_ids = {str(u.id) for u in admin_group.users}
basic_ids = {str(u.id) for u in basic_group.users}
return admin_ids, basic_ids
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.LIMITED,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify no group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert (
user_id_str not in admin_ids
), "LIMITED API key should NOT be in Admin default group"
assert (
user_id_str not in basic_ids
), "LIMITED API key should NOT be in Basic default group"
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.BASIC,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Basic group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
assert (
user_id_str not in admin_ids
), "BASIC API key should NOT be in Admin default group"
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.ADMIN,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Admin group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
assert (
user_id_str not in basic_ids
), "ADMIN API key should NOT be in Basic default group"

Some files were not shown because too many files have changed in this diff Show More