Compare commits

..

80 Commits

Author SHA1 Message Date
Dane Urban
80cf389774 . 2026-02-23 16:30:30 -08:00
Danelegend
e775aaacb7 chore: preview modal (#8665) 2026-02-23 16:29:13 -08:00
Justin Tahara
e5b08b3d92 fix(search): Improve Speed (#8430) 2026-02-23 16:29:13 -08:00
Jamison Lahman
7c91304ba2 chore(playwright): warn user if setup takes longer than usual (#8690) 2026-02-23 16:29:13 -08:00
roshan
68a292b500 fix(ui): Clean up NRF settings button styling (#8678)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 16:29:13 -08:00
Justin Tahara
e553b80030 fix(db): Multitenant Schema migration update (#8679) 2026-02-23 16:29:13 -08:00
Justin Tahara
f3949f8e09 chore(ods): Automated Cherry-pick backport (#8642) 2026-02-23 16:29:13 -08:00
Nikolas Garza
c7c064e296 feat(scim): Okta compatibility + provider abstraction (#8568) 2026-02-23 16:29:13 -08:00
Wenxi
68b91a8862 fix: domain rules for signup on cloud (#8671) 2026-02-23 16:29:13 -08:00
roshan
c23e5a196d fix: Handle unauthenticated state gracefully on NRF page (#8491)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 16:29:13 -08:00
Raunak Bhagat
093223c6c4 refactor: migrate Web Search page to SettingsLayouts + Content (#8662) 2026-02-23 16:29:13 -08:00
Danelegend
89517111d4 feat: Add code interpreter server db model (#8669) 2026-02-23 16:29:13 -08:00
Wenxi
883d4b4ceb chore: set trial api usage to 0 and show ui (#8664) 2026-02-23 16:29:13 -08:00
Dane Urban
f3672b6819 CSV rendering 2026-02-22 18:33:39 -08:00
Dane Urban
921f5d9e96 preview modal 2026-02-22 17:42:30 -08:00
SubashMohan
15fe47adc5 fix: improve connector status display, agent search tool detection, and prompt formatting (#8579) 2026-02-22 07:22:06 +00:00
SubashMohan
29958f1a52 perf(open-url): parallelize URL fetching with split connect/read timeouts (#8580) 2026-02-22 06:45:26 +00:00
SubashMohan
ac7f9838bc feat: add mixed content handler for chat and image generation packets (#8494) 2026-02-22 06:12:48 +00:00
Raunak Bhagat
d0fa4b3319 fix: limit connectors section to 3 items in Chat Preferences (#8661) 2026-02-22 01:01:42 +00:00
Raunak Bhagat
3fb4fb422e feat: refreshed admin Chat Preferences page (#8488)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 04:52:04 +00:00
acaprau
ba5da22ea1 chore(documentation): Add comment in contributing_guides/best_practices.md about accountability in TODOs (#8659) 2026-02-21 04:47:48 +00:00
Evan Lohn
9909049047 fix: improve eager loading behavior (#8576) 2026-02-21 04:02:34 +00:00
Nikolas Garza
c516aa3e3c fix(search): eliminate DB pool exhaustion from SearchTool parallel queries (#8651) 2026-02-21 03:56:00 +00:00
Evan Lohn
5cc6220417 feat: checkpointing mid drive (#8606) 2026-02-21 03:47:28 +00:00
acaprau
15da1e0a88 feat(opensearch): Add OpenSearch in docker compose (#8611) 2026-02-21 02:57:31 +00:00
Jamison Lahman
e9ff00890b chore(fe): fix Pinned icon hover background when card is hovered (#8656) 2026-02-21 02:37:18 +00:00
Wenxi
67747a9d93 fix: unify provider retrieval into swr hook (#8609) 2026-02-21 02:15:28 +00:00
Jamison Lahman
edfc51b439 fix(fe): Truncated tooltip is disabled when not needed to be shown (#8629) 2026-02-21 02:07:18 +00:00
Evan Lohn
ac4fba947e feat: sharepoint scalability5 (#8631) 2026-02-21 01:56:11 +00:00
Nikolas Garza
c142b2db02 chore: special sauce (#8646) 2026-02-21 01:30:36 +00:00
Jamison Lahman
fb7e7e4395 fix(fe): keep focus on input on empty (#8627) 2026-02-21 00:13:26 +00:00
Justin Tahara
113f23398e fix(celery): Guardrail for User File Processing (#8633) 2026-02-20 22:43:01 +00:00
Danelegend
5a8716026a feat: Python tool call replay packets (#8649) 2026-02-20 22:28:06 +00:00
Jamison Lahman
3389140bfd chore(deps): @radix-ui/react-tooltip: v1.1.3->v1.2.8 (#8647) 2026-02-20 22:27:49 +00:00
Justin Tahara
13109e7b81 chore(llm): Removal of Retired Models + Cleanup (#8645) 2026-02-20 21:47:53 +00:00
Jessica Singh
56ad457168 chore(auth): tests cleanup (#8559) 2026-02-20 20:51:29 +00:00
Nikolas Garza
a81aea2afc feat(scim): add scim_username column to ScimUserMapping (#8635) 2026-02-20 12:44:16 -08:00
Jamison Lahman
7cb5c9c4a6 chore(fe): replace BlinkingDot with a BlinkingBar (#8640) 2026-02-20 20:02:19 +00:00
Evan Lohn
3520c58a22 feat: sharepoint scalability4 (#8551) 2026-02-20 19:35:36 +00:00
Nikolas Garza
bd9d1bfa27 chore(helm): update code-interpreter chart repo URL to python-sandbox (#8625) 2026-02-20 11:39:08 -08:00
Evan Lohn
14416cc3db fix: search tool enabled when nothing selected (#8637) 2026-02-20 19:22:28 +00:00
Jamison Lahman
d7fce14d26 chore(devtools): upgrade ods: 0.5.7->0.6.0 (#8628) 2026-02-20 10:18:27 -08:00
Wenxi
39a8d8ed05 fix: boolean form field click on text, open url tool checkbox in default assistant, and simple tooltip rendering (#8480) 2026-02-20 17:38:29 +00:00
Jamison Lahman
82f735a434 chore(fe): rm settings page top padding (#8621) 2026-02-20 17:37:19 +00:00
Jamison Lahman
aadb58518b fix(fe): popover width can fit trigger element (#8624) 2026-02-20 07:29:21 +00:00
Jamison Lahman
0755499e0f chore(fe): whitelabel logo layout followup (#8626) 2026-02-19 23:27:37 -08:00
Danelegend
27aaf977a2 feat(code interpreter): improve produced code artifacts (#8507)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-20 07:10:09 +00:00
Evan Lohn
9f707f195e feat: sharepoint scalability3 (#8537) 2026-02-20 05:34:50 +00:00
Justin Tahara
3e35570f70 feat(vertex ai): Image Gen Historical Context (#8603) 2026-02-20 05:31:41 +00:00
Jamison Lahman
53b1bf3b2c chore(fe): fix nested button hydration error (#8599) 2026-02-20 05:14:24 +00:00
Jamison Lahman
5a3fa6b648 chore(fe): fix whitelabel logo moving on sidebar close (#8577) 2026-02-20 04:30:01 +00:00
Evan Lohn
fc6a37850b feat: delta sync sharepoint (#8532)
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-20 03:26:54 +00:00
Raunak Bhagat
aa6fec3d58 Fix/pat UI (#8617) 2026-02-19 19:23:26 -08:00
Raunak Bhagat
efa6005e36 feat(opal): Add widthVariant to Interactive.Container (#8610) 2026-02-20 03:14:30 +00:00
Nikolas Garza
921bfc72f4 fix(helm): route /openapi.json to api_server in nginx config (#8612) 2026-02-19 19:05:14 -08:00
Justin Tahara
812603152d feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 02:43:30 +00:00
Raunak Bhagat
6779d8fbd7 feat(opal): Add Content layout component (#8534)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 02:35:27 +00:00
Justin Tahara
2c9826e4a9 feat(web): Logical Hardening for Brave Web Search 2/3 (#8595) 2026-02-20 02:09:49 +00:00
Justin Tahara
5b54687077 feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 01:38:19 +00:00
Raunak Bhagat
0f7e2ee674 feat(opal): Add Tag component and resync colors with Figma (#8533)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 01:14:52 +00:00
Danelegend
ea466648d9 feat: file preview from llm (#8604) 2026-02-20 00:47:03 +00:00
Evan Lohn
a402911ee6 feat: sharepoint scalability 1 (#8531) 2026-02-20 00:41:48 +00:00
Wenxi
7ae9ba807d fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 23:41:25 +00:00
Nikolas Garza
1f79223c42 feat(ods): add --continue flag and cp alias to cherry-pick (#8601) 2026-02-19 22:31:06 +00:00
Danelegend
c0c2247d5a feat: Modal Header no icon (#8596) 2026-02-19 21:59:18 +00:00
Danelegend
2989ceda41 feat: add file formatting reminder (#8524) 2026-02-19 21:51:06 +00:00
Wenxi
c825f5eca6 feat: whitelist invite setting, allow users to register and invite, new accoount blocked page (#8527) 2026-02-19 20:16:50 +00:00
Jamison Lahman
a8965def79 chore(playwright): fix chat scrolling non-determinism (#8584) 2026-02-19 19:26:09 +00:00
Jamison Lahman
59e1ad51ba chore(fe): fix drop-down overflow in API Key modal (#8574) 2026-02-19 19:15:10 +00:00
Jamison Lahman
0e70a8f826 chore(fe): remove close button from image gen tooltip (#8585) 2026-02-19 18:31:02 +00:00
SubashMohan
0891737dfd fix: update SourceTag component to use variant prop for sizing (#8582) 2026-02-19 17:27:04 +00:00
victoria reese
5a20112670 fix: define fallback only on custom metrics (#8566) 2026-02-19 09:19:44 -08:00
SubashMohan
584f2e2638 fix(ui): Fix admin UI layout, sticky headers, LLM filtering, and project view issues (#8496) 2026-02-19 14:02:22 +00:00
Yuhong Sun
aa24b16ec1 chore: Opensearch tuning (#8518) 2026-02-19 05:43:45 +00:00
acaprau
50aa9d7df6 feat(opensearch): Display percentage progress in the migration page (#8575) 2026-02-19 05:33:29 +00:00
acaprau
bfda586054 chore(opensearch): Make the migration visit from 4 independent Vespa slices concurrently (#8570) 2026-02-19 04:26:59 +00:00
roshan
e04392fbb1 feat(craft): shareable webapp URLs with two-tier access control (#8510)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-19 04:04:21 +00:00
Wenxi
e46c6c5175 fix: init all celery tasks for autodiscovery (#8539) 2026-02-19 03:45:40 +00:00
Nikolas Garza
f59792b4ac feat(metrics): add SQLAlchemy connection pool Prometheus metrics (#8543) 2026-02-19 02:32:42 +00:00
Justin Tahara
973b9456e9 fix(vertex ai): Sort Model Names (#8572) 2026-02-19 02:24:47 +00:00
375 changed files with 13856 additions and 5664 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Build chart dependencies

View File

@@ -0,0 +1,79 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR body and check whether the helper checkbox is checked.
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
if: steps.gate.outputs.should_cherrypick == 'true'
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
run: |
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,9 +45,6 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -118,9 +115,10 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -129,7 +127,6 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Install Redis operator

View File

@@ -21,15 +21,14 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, NamedTuple
from typing import NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -0,0 +1,28 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -0,0 +1,32 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -0,0 +1,31 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -0,0 +1,31 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -127,9 +127,14 @@ class ScimDAL(DAL):
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
)
self._session.add(mapping)
self._session.flush()
return mapping
@@ -248,11 +253,11 @@ class ScimDAL(DAL):
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
A tuple of (list of (user, mapping) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
@@ -292,33 +297,104 @@ class ScimDAL(DAL):
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
.unique()
.all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
# ------------------------------------------------------------------
# Group mapping operations
@@ -483,9 +559,13 @@ class ScimDAL(DAL):
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users_by_id = {u.id: u for u in users}
return [
@@ -504,9 +584,13 @@ class ScimDAL(DAL):
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]

View File

@@ -9,6 +9,7 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session, only_up_to_date: bool = True
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -209,6 +266,8 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -216,11 +275,16 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID, only_curator_groups: bool = False
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def construct_document_id_select_by_usergroup(

View File

@@ -69,7 +69,7 @@ def _graph_api_get(
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout):
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
num_hits: int = 30
include_content: bool = False
stream: bool = False

View File

@@ -26,12 +26,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
@@ -41,6 +39,8 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -53,7 +53,6 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
@@ -63,6 +62,18 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@@ -100,28 +111,6 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -155,9 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or name.formatted
# ---------------------------------------------------------------------------
@@ -171,6 +161,7 @@ def list_users(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
@@ -183,12 +174,19 @@ def list_users(
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
)
for user, mapping in users_with_mappings
]
return ScimListResponse(
@@ -203,6 +201,7 @@ def list_users(
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
@@ -215,20 +214,26 @@ def get_user(
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
@@ -264,11 +269,14 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
scim_username = user_resource.userName.strip()
dal.create_user_mapping(
external_id=external_id, user_id=user.id, scim_username=scim_username
)
dal.commit()
return _user_to_scim(user, external_id)
return provider.build_user_resource(user, external_id, scim_username=scim_username)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -276,6 +284,7 @@ def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
@@ -293,19 +302,27 @@ def replace_user(
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
email=user_resource.userName.strip(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
personal_name=personal_name,
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
scim_username = user_resource.userName.strip()
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
dal.commit()
return _user_to_scim(user, new_external_id)
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
@scim_router.patch("/Users/{user_id}", response_model=None)
@@ -313,6 +330,7 @@ def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
@@ -330,11 +348,19 @@ def patch_user(
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current = _user_to_scim(user, external_id)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
)
try:
patched = apply_user_patch(patch_request.Operations, current)
patched = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -345,22 +371,40 @@ def patch_user(
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
personal_name=personal_name,
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.sync_user_external_id(
user.id, patched.externalId, scim_username=new_scim_username
)
dal.commit()
return _user_to_scim(user, patched.externalId)
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -398,24 +442,6 @@ def delete_user(
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
@@ -474,6 +500,7 @@ def list_groups(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
@@ -491,7 +518,7 @@ def list_groups(
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
@@ -507,6 +534,7 @@ def list_groups(
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
@@ -521,13 +549,16 @@ def get_group(
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
@@ -565,7 +596,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
return provider.build_group_resource(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -573,6 +604,7 @@ def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
@@ -595,7 +627,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
return provider.build_group_resource(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -603,6 +635,7 @@ def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
@@ -621,11 +654,11 @@ def patch_group(
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
current = provider.build_group_resource(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -652,7 +685,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
return provider.build_group_resource(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)

View File

@@ -63,6 +63,13 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -76,8 +83,10 @@ class ScimUserResource(BaseModel):
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
@@ -121,12 +130,40 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
value: ScimPatchValue = None
class ScimPatchRequest(BaseModel):

View File

@@ -16,9 +16,12 @@ from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
@@ -41,9 +44,15 @@ _MEMBER_FILTER_RE = re.compile(
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
@@ -55,9 +64,9 @@ def apply_user_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
_apply_user_replace(op, data, name_data, ignored_paths)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
_apply_user_replace(op, data, name_data, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
@@ -71,30 +80,34 @@ def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
_set_user_field(path, op.value, data, name_data, ignored_paths)
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
value: ScimPatchValue,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
if path in ignored_paths:
return
elif path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
@@ -107,7 +120,7 @@ def _set_user_field(
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
data["displayName"] = value
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
@@ -116,9 +129,15 @@ def _set_user_field(
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -133,7 +152,9 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -154,38 +175,48 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
_set_group_field(key.lower(), val, data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
return
_set_group_field(path, op.value, data)
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
def _replace_members(
value: str | list | dict | bool | None,
value: list[dict],
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -197,11 +228,14 @@ def _replace_members(
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
value: ScimPatchValue,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
if path in ignored_paths:
return
elif path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
@@ -223,8 +257,10 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in op.value:
for member_data in member_dicts:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -0,0 +1,123 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
"""
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=username,
name=self._build_scim_name(user),
displayName=user.personal_name,
emails=[ScimEmail(value=username, type="work", primary=True)],
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -0,0 +1,25 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return frozenset({"id", "schemas", "meta"})

View File

@@ -37,12 +37,15 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -121,6 +121,7 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,6 +138,8 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -208,22 +211,34 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
whitelist = get_invited_users()
if not whitelist:
if not workspace_invite_only_enabled():
return
whitelist = get_invited_users()
if not email:
raise PermissionError("Email must be specified")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
for email_whitelist in whitelist:
try:
@@ -240,7 +255,13 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -256,13 +277,32 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
)
# Check domain whitelist if configured
@@ -1650,7 +1690,10 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
next_url = request.query_params.get("next", "/")

View File

@@ -1,10 +0,0 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,3 +41,14 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,6 +8,12 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -47,7 +53,13 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
# shared_task allows this task to be shared across celery app instances.
@@ -76,11 +88,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token stored in the
tracked via a continuation token map stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -153,15 +169,28 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token,
continuation_token_map,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if continuation_token is None and total_chunks_migrated > 0:
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -170,19 +199,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token: {continuation_token}"
f"Continuation token map: {continuation_token_map}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token = (
raw_vespa_chunks, next_continuation_token_map = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token=continuation_token,
continuation_token_map=continuation_token_map,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token: {next_continuation_token}"
f"seconds. Next continuation token map: {next_continuation_token_map}"
)
opensearch_document_chunks, errored_chunks = (
@@ -212,14 +241,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token=next_continuation_token,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -37,6 +37,35 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.

View File

@@ -1,8 +0,0 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -21,12 +22,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -57,6 +60,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -120,7 +134,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -135,7 +166,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -148,12 +193,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -161,7 +229,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -304,6 +373,12 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,3 +1,4 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -45,6 +46,7 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -422,6 +424,40 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -582,10 +618,24 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -57,6 +57,7 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -651,6 +652,7 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -761,6 +763,7 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -900,6 +903,18 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -10,6 +10,7 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -125,6 +126,7 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -132,6 +134,8 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None
@@ -186,7 +190,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "".join(sections)
return USER_INFORMATION_HEADER + "\n".join(sections)
def build_system_prompt(
@@ -224,23 +228,21 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -250,12 +252,14 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
system_prompt += INTERNAL_SEARCH_GUIDANCE
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -265,20 +269,23 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
)
if has_open_urls or include_all_guidance:
system_prompt += OPEN_URLS_GUIDANCE
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
if has_python or include_all_guidance:
system_prompt += PYTHON_TOOL_GUIDANCE
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
if has_generate_image or include_all_guidance:
system_prompt += GENERATE_IMAGE_GUIDANCE
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
if has_memory or include_all_guidance:
system_prompt += MEMORY_GUIDANCE
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
return system_prompt

View File

@@ -251,7 +251,9 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -263,6 +265,18 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
@@ -270,6 +284,9 @@ OPENSEARCH_PROFILING_DISABLED = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.

View File

@@ -157,6 +157,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -443,6 +454,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -71,6 +71,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
SLIM_BATCH_SIZE = 1000
_EPOCH = datetime.fromtimestamp(0, tz=timezone.utc)
SHARED_DOCUMENTS_MAP = {
@@ -243,6 +244,12 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
current_drive_name: str | None = None
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
current_drive_web_url: str | None = None
# Resolved drive ID — avoids re-resolving on checkpoint resume
current_drive_id: str | None = None
# Next delta API page URL for per-page checkpointing within a drive.
# When set, Phase 3b fetches one page at a time so progress is persisted
# between pages. None means BFS path or no active delta traversal.
current_drive_delta_next_link: str | None = None
process_site_pages: bool = False
@@ -1266,7 +1273,8 @@ class SharepointConnector(
"""
base = f"{self.graph_api_base}/drives/{drive_id}"
if folder_path:
start_url = f"{base}/root:/{folder_path}:/children"
encoded_path = quote(folder_path, safe="/")
start_url = f"{base}/root:/{encoded_path}:/children"
else:
start_url = f"{base}/root/children"
@@ -1322,13 +1330,12 @@ class SharepointConnector(
Falls back to full enumeration if the API returns 410 Gone (expired token).
"""
EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
use_timestamp_token = start is not None and start > EPOCH
use_timestamp_token = start is not None and start > _EPOCH
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
if use_timestamp_token:
assert start is not None # purely for mypy
token = quote(start.strftime("%Y-%m-%dT%H:%M:%SZ"))
assert start is not None # mypy
token = quote(start.isoformat(timespec="seconds"))
initial_url += f"?token={token}"
yield from self._iter_delta_pages(
@@ -1361,6 +1368,7 @@ class SharepointConnector(
try:
data = self._graph_api_get_json(page_url, params)
except requests.HTTPError as e:
# 410 means the delta token expired, so we need to fall back to full enumeration
if e.response is not None and e.response.status_code == 410:
if not allow_full_resync:
raise
@@ -1397,10 +1405,91 @@ class SharepointConnector(
yield DriveItemData.from_graph_json(item)
page_url = data.get("@odata.nextLink") or data.get("@odata.deltaLink")
if "@odata.deltaLink" in data and "@odata.nextLink" not in data:
page_url = data.get("@odata.nextLink")
if not page_url:
break
def _build_delta_start_url(
self,
drive_id: str,
start: datetime | None = None,
page_size: int = 200,
) -> str:
"""Build the initial delta API URL with query parameters embedded.
Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
URL so that the returned string is fully self-contained and can be
stored in a checkpoint without needing a separate params dict.
"""
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
params = [f"$top={page_size}"]
if start is not None and start > _EPOCH:
token = quote(start.isoformat(timespec="seconds"))
params.append(f"token={token}")
return f"{base_url}?{'&'.join(params)}"
def _fetch_one_delta_page(
self,
page_url: str,
drive_id: str,
start: datetime | None = None,
end: datetime | None = None,
page_size: int = 200,
) -> tuple[list[DriveItemData], str | None]:
"""Fetch a single page of delta API results.
Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
the delta enumeration is complete (deltaLink with no nextLink).
On 410 Gone (expired token) returns ``([], full_resync_url)`` so
the caller can store the resync URL in the checkpoint and retry on
the next cycle.
"""
try:
data = self._graph_api_get_json(page_url)
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 410:
logger.warning(
"Delta token expired (410 Gone) for drive '%s'. "
"Will restart with full delta enumeration.",
drive_id,
)
full_url = (
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
f"?$top={page_size}"
)
return [], full_url
raise
items: list[DriveItemData] = []
for item in data.get("value", []):
if "folder" in item or "deleted" in item:
continue
if start is not None or end is not None:
raw_ts = item.get("lastModifiedDateTime")
if raw_ts:
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
if start is not None and mod_dt < start:
continue
if end is not None and mod_dt > end:
continue
items.append(DriveItemData.from_graph_json(item))
next_url = data.get("@odata.nextLink")
if next_url:
return items, next_url
return items, None
@staticmethod
def _clear_drive_checkpoint_state(
checkpoint: "SharepointConnectorCheckpoint",
) -> None:
"""Reset all drive-level fields in the checkpoint."""
checkpoint.current_drive_name = None
checkpoint.current_drive_id = None
checkpoint.current_drive_web_url = None
checkpoint.current_drive_delta_next_link = None
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()
@@ -1842,14 +1931,13 @@ class SharepointConnector(
# Return checkpoint to allow persistence after drive initialization
return checkpoint
# Phase 3: Process documents from current drive
# Phase 3a: Initialize the next drive for processing
if (
checkpoint.current_site_descriptor
and checkpoint.cached_drive_names
and len(checkpoint.cached_drive_names) > 0
and checkpoint.current_drive_name is None
):
checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
@@ -1857,7 +1945,8 @@ class SharepointConnector(
site_descriptor = checkpoint.current_site_descriptor
logger.info(
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
f"Processing drive '{checkpoint.current_drive_name}' "
f"in site: {site_descriptor.url}"
)
logger.debug(f"Time range: {start_dt} to {end_dt}")
@@ -1866,35 +1955,35 @@ class SharepointConnector(
logger.warning("Current drive name is None, skipping")
return checkpoint
driveitems: Iterable[DriveItemData] = iter(())
drive_web_url: str | None = None
try:
logger.info(
f"Fetching drive items for drive name: {current_drive_name}"
)
result = self._resolve_drive(site_descriptor, current_drive_name)
if result is not None:
drive_id, drive_web_url = result
driveitems = self._get_drive_items_for_drive_id(
site_descriptor, drive_id, start_dt, end_dt
)
checkpoint.current_drive_web_url = drive_web_url
if result is None:
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
drive_id, drive_web_url = result
checkpoint.current_drive_id = drive_id
checkpoint.current_drive_web_url = drive_web_url
except Exception as e:
logger.error(
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
f"Failed to retrieve items from drive '{current_drive_name}' "
f"in site: {site_descriptor.url}: {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
f"Failed to access drive '{current_drive_name}' "
f"in site '{site_descriptor.url}': {str(e)}",
(start_dt, end_dt),
e,
)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
current_drive_name = SHARED_DOCUMENTS_MAP.get(
display_drive_name = SHARED_DOCUMENTS_MAP.get(
current_drive_name, current_drive_name
)
@@ -1902,10 +1991,74 @@ class SharepointConnector(
yield from self._yield_drive_hierarchy_node(
site_descriptor.url,
drive_web_url,
current_drive_name,
display_drive_name,
checkpoint,
)
# For non-folder-scoped drives, use delta API with per-page
# checkpointing. Build the initial URL and fall through to 3b.
if not site_descriptor.folder_path:
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
drive_id, start_dt
)
# else: BFS path — delta_next_link stays None;
# Phase 3b will use _iter_drive_items_paged.
# Phase 3b: Process items from the current drive
if (
checkpoint.current_site_descriptor
and checkpoint.current_drive_name is not None
and checkpoint.current_drive_id is not None
):
site_descriptor = checkpoint.current_site_descriptor
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
current_drive_name = SHARED_DOCUMENTS_MAP.get(
checkpoint.current_drive_name, checkpoint.current_drive_name
)
drive_web_url = checkpoint.current_drive_web_url
# --- determine item source ---
driveitems: Iterable[DriveItemData]
has_more_delta_pages = False
if checkpoint.current_drive_delta_next_link:
# Delta path: fetch one page at a time for checkpointing
try:
page_items, next_url = self._fetch_one_delta_page(
page_url=checkpoint.current_drive_delta_next_link,
drive_id=checkpoint.current_drive_id,
start=start_dt,
end=end_dt,
)
except Exception as e:
logger.error(
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
driveitems = page_items
has_more_delta_pages = next_url is not None
if next_url:
checkpoint.current_drive_delta_next_link = next_url
else:
# BFS path (folder-scoped): process all items at once
driveitems = self._iter_drive_items_paged(
drive_id=checkpoint.current_drive_id,
folder_path=site_descriptor.folder_path,
start=start_dt,
end=end_dt,
)
item_count = 0
for driveitem in driveitems:
item_count += 1
@@ -1947,8 +2100,6 @@ class SharepointConnector(
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)
# Re-acquire token in case it expired during a long traversal
# MSAL has a cache that returns the same token while still valid.
access_token = self._get_graph_access_token()
doc_or_failure = _convert_driveitem_to_document_with_permissions(
driveitem,
@@ -1984,8 +2135,11 @@ class SharepointConnector(
)
logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
if has_more_delta_pages:
return checkpoint
self._clear_drive_checkpoint_state(checkpoint)
# Phase 4: Progression logic - determine next step
# If we have more drives in current site, continue with current site

View File

@@ -32,6 +32,7 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -905,13 +906,15 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session,
db_session: Session | None = None,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -925,7 +928,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session
db_session: Database session (optional if search_settings provided)
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1153,7 +1156,10 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
search_settings = get_current_search_settings(db_session)
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,8 +18,10 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -41,7 +43,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session,
db_session: Session | None = None,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -49,18 +51,19 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
base_filters = user_provided_filters or BaseFilters()
if (
user_provided_filters
and user_provided_filters.document_set is None
and persona_document_sets is not None
):
base_filters.document_set = persona_document_sets
document_set_filter = (
base_filters.document_set
if base_filters.document_set is not None
else persona_document_sets
)
time_filter = base_filters.time_cutoff or persona_time_cutoff
source_filter = base_filters.source_type
@@ -103,15 +106,20 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
project_id=project_id,
source_type=source_filter,
document_set=persona_document_sets,
document_set=document_set_filter,
time_cutoff=time_filter,
tags=base_filters.tags,
access_control_list=user_acl_filters,
@@ -252,11 +260,15 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session,
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -297,6 +309,7 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -315,6 +328,8 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,9 +14,11 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -50,9 +52,14 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(query_request.query, db_session)
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -78,7 +85,9 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -88,14 +97,22 @@ def search_chunks(
else None
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -114,7 +131,10 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(_embed_and_search, (query_request, document_index, db_session))
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,23 +64,34 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -116,12 +116,15 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -130,7 +133,10 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -170,6 +176,7 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -183,6 +190,7 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,10 +554,19 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(selectinload(DocumentSetDBModel.federated_connectors))
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).all()
return db_session.scalars(stmt).unique().all()
def fetch_documents_for_document_set_paginated(

View File

@@ -1,11 +1,102 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_schemas_needing_migration(
tenant_schemas: list[str], head_rev: str
) -> list[str]:
"""Return only schemas whose current alembic version is not at head.
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
into a temp table one at a time. This avoids building a massive UNION ALL
query (which locks the DB and times out at 17k+ schemas) and instead
acquires locks sequentially, one schema per iteration.
"""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Populate a temp input table with exactly the schemas we care about.
# The DO block reads from this table so it only iterates the requested
# schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
conn.execute(
text(
"INSERT INTO _tenant_schemas_input (schema_name) "
"SELECT unnest(CAST(:schemas AS text[]))"
),
{"schemas": tenant_schemas},
)
conn.execute(
text(
"CREATE TEMP TABLE _alembic_version_snapshot "
"(schema_name text, version_num text)"
)
)
conn.execute(
text(
"""
DO $$
DECLARE
s text;
schemas text[];
BEGIN
SELECT array_agg(schema_name) INTO schemas
FROM _tenant_schemas_input;
IF schemas IS NULL THEN
RAISE NOTICE 'No tenant schemas found.';
RETURN;
END IF;
FOREACH s IN ARRAY schemas LOOP
BEGIN
EXECUTE format(
'INSERT INTO _alembic_version_snapshot
SELECT %L, version_num FROM %I.alembic_version',
s, s
);
EXCEPTION
-- undefined_table: schema exists but has no alembic_version
-- table yet (new tenant, not yet migrated).
-- invalid_schema_name: tenant is registered but its
-- PostgreSQL schema does not exist yet (e.g. provisioning
-- incomplete). Both cases mean no version is available and
-- the schema will be included in the migration list.
WHEN undefined_table THEN NULL;
WHEN invalid_schema_name THEN NULL;
END;
END LOOP;
END;
$$
"""
)
)
rows = conn.execute(
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
)
version_by_schema = {row[0]: row[1] for row in rows}
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
# Schemas missing from the snapshot have no alembic_version table yet and
# also need migration. version_by_schema.get(s) returns None for those,
# and None != head_rev, so they are included automatically.
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

View File

@@ -232,6 +232,12 @@ class BuildSessionStatus(str, PyEnum):
IDLE = "idle"
class SharingScope(str, PyEnum):
PRIVATE = "private"
PUBLIC_ORG = "public_org"
PUBLIC_GLOBAL = "public_global"
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"

View File

@@ -77,6 +77,7 @@ from onyx.db.enums import (
ThemePreference,
DefaultAppMode,
SwitchoverType,
SharingScope,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -286,7 +287,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
"Credential", back_populates="user"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -320,7 +321,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -1040,7 +1040,9 @@ class OpenSearchTenantMigrationRecord(Base):
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started" or "visit completed".
# NULL means "not started".
# Otherwise contains a serialized mapping between slice ID and continuation
# token for that slice.
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
@@ -1064,6 +1066,9 @@ class OpenSearchTenantMigrationRecord(Base):
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
Integer, nullable=True
)
class KGEntityType(Base):
@@ -4712,6 +4717,12 @@ class BuildSession(Base):
demo_data_enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
sharing_scope: Mapped[SharingScope] = mapped_column(
String,
nullable=False,
default=SharingScope.PRIVATE,
server_default="private",
)
# Relationships
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
@@ -4928,6 +4939,7 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4966,3 +4978,12 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -4,6 +4,7 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
import json
from datetime import datetime
from datetime import timezone
@@ -12,6 +13,9 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
@@ -243,29 +247,37 @@ def should_document_migration_be_permanently_failed(
def get_vespa_visit_state(
db_session: Session,
) -> tuple[str | None, int]:
) -> tuple[dict[int, str | None], int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
Tuple of (continuation_token_map, total_chunks_migrated).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
if record.vespa_visit_continuation_token is None:
continuation_token_map: dict[int, str | None] = {
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}
else:
json_loaded_continuation_token_map = json.loads(
record.vespa_visit_continuation_token
)
continuation_token_map = {
int(key): value for key, value in json_loaded_continuation_token_map.items()
}
return continuation_token_map, record.total_chunks_migrated
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token: str | None,
continuation_token_map: dict[int, str | None],
chunks_processed: int,
chunks_errored: int,
approx_chunk_count_in_vespa: int | None,
) -> None:
"""Updates the Vespa migration progress and commits.
@@ -273,19 +285,26 @@ def update_vespa_visit_progress_with_commit(
Args:
db_session: SQLAlchemy session.
continuation_token: The new continuation token. None means the visit
is complete.
continuation_token_map: The new continuation token map. None entry means
the visit is complete for that slice.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
None, the existing value is used.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = continuation_token
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
record.approx_chunk_count_in_vespa = (
approx_chunk_count_in_vespa
if approx_chunk_count_in_vespa is not None
else record.approx_chunk_count_in_vespa
)
db_session.commit()
@@ -353,25 +372,27 @@ def build_sanitized_to_original_doc_id_mapping(
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None]:
) -> tuple[int, datetime | None, datetime | None, int | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None.
None, None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
approx_chunk_count_in_vespa).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None
return 0, None, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
record.approx_chunk_count_in_vespa,
)

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -31,55 +32,61 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.limit(1)
.options(selectinload(User.memories))
)
row = result.first()
if not row:
if not user:
return None
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update_last_used())
_schedule_pat_last_used_update(hashed_token, now)
return user
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
def create_pat(
db_session: Session,
user_id: UUID,

View File

@@ -28,6 +28,7 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -420,9 +421,16 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -453,7 +461,16 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -550,9 +567,16 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
)
@@ -611,7 +635,16 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

View File

@@ -54,6 +54,9 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
@@ -706,10 +709,12 @@ class OpenSearchClient:
)
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
)
search_hits.append(search_hit)
logger.debug(

View File

@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# Number of vectors to examine for top k neighbors for the HNSW method.
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
# Bumped this to 1000, for dataset of low 10,000 docs, did not see improvement in recall.
EF_SEARCH = 256
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
# a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# The default number of neighbors to consider for knn vector similarity search.
# We need this higher than the number of results because the scoring is hybrid.
# If there is only 1 query, setting k equal to the number of results is enough,
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
# Number of vectors to examine for top k neighbors for the HNSW method.
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_TITLE_KEYWORD_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_CONTENT_KEYWORD_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0

View File

@@ -842,6 +842,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -11,6 +11,7 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -54,6 +55,11 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
# Faiss was also tried but it didn't have any benefits
# NMSLIB is deprecated, not recommended
OPENSEARCH_KNN_ENGINE = "lucene"
def get_opensearch_doc_chunk_id(
tenant_state: TenantState,
document_id: str,
@@ -343,6 +349,9 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
@@ -357,9 +366,7 @@ class DocumentSchema:
CONTENT_FIELD_NAME: {
"type": "text",
"store": True,
# This makes highlighting text during queries more efficient
# at the cost of disk space. See
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"index_options": "offsets",
},
TITLE_VECTOR_FIELD_NAME: {
@@ -368,7 +375,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"engine": OPENSEARCH_KNN_ENGINE,
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
@@ -380,7 +387,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"engine": OPENSEARCH_KNN_ENGINE,
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},

View File

@@ -6,13 +6,16 @@ from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
@@ -240,6 +243,9 @@ class DocumentQuery:
Returns:
A dictionary representing the final hybrid search query.
"""
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
@@ -247,7 +253,7 @@ class DocumentQuery:
)
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
query_text, query_vector
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -275,25 +281,31 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Applied to all the sub-queries. Source:
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
# Sources:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
# Explain is for scoring breakdowns.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
return final_hybrid_search_body
@@ -355,7 +367,12 @@ class DocumentQuery:
@staticmethod
def _get_hybrid_search_subqueries(
query_text: str, query_vector: list[float], num_candidates: int
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# for a detailed breakdown, see where the default value is set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -367,9 +384,8 @@ class DocumentQuery:
Matches:
- Title vector
- Title keyword
- Content vector
- Content keyword + phrase
- Keyword (title + content, match and phrase)
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
@@ -390,9 +406,9 @@ class DocumentQuery:
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
and very low number of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable but in reality seeing the
user usage patterns, this is not very common and people tend to not be confused when a miss happens for this reason.
In testing datasets, this makes recall slightly worse.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
less performant so not really any reason to do it.
Args:
query_text: The text of the query to search for.
@@ -401,64 +417,56 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, title keyword, content vector,
# content keyword.
# pipeline weights: title vector, content vector, keyword (title + content).
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
"k": vector_candidates,
}
}
},
# 2. Title keyword + phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}
},
]
}
},
# 3. Content vector search
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
"k": vector_candidates,
}
}
},
# 4. Content keyword + phrase search.
# 3. Keyword (title + content) match and phrase search.
{
"bool": {
"should": [
{
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
{
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
"slop": 1,
"boost": 0.2,
}
}
},
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
"boost": 1.0,
}
}
},
@@ -466,9 +474,7 @@ class DocumentQuery:
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}

View File

@@ -10,6 +10,12 @@ from typing import cast
import httpx
from retry import retry
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
FIELDS_NEEDED_FOR_TRANSFORMATION,
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.context.search.models import IndexFilters
@@ -277,54 +283,139 @@ def get_chunks_via_visit_api(
def get_all_chunks_paginated(
index_name: str,
tenant_state: TenantState,
continuation_token: str | None = None,
page_size: int = 1_000,
) -> tuple[list[dict], str | None]:
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict], dict[int, str | None]]:
"""Gets all chunks in Vespa matching the filters, paginated.
Uses the Visit API with slicing. Each continuation token map entry is for a
different slice. The number of entries determines the number of slices.
Args:
index_name: The name of the Vespa index to visit.
tenant_state: The tenant state to filter by.
continuation_token: Token returned by Vespa representing a page offset.
None to start from the beginning. Defaults to None.
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset. None to start from the beginning of the
slice.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
def _get_all_chunks_paginated_for_slice(
index_name: str,
tenant_state: TenantState,
slice_id: int,
total_slices: int,
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict], str | None]:
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
)
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
params: dict[str, str | int | None] = {
"selection": selection,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
}
if continuation_token is not None:
params["continuation"] = continuation_token
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to get chunks in Vespa."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
params: dict[str, str | int | None] = {
"selection": selection,
"fieldSet": field_set,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
}
if continuation_token is not None:
params["continuation"] = continuation_token
response: httpx.Response | None = None
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
error_message = (
response.json().get("message") if response else "No response"
)
logger.error("Error message from response: %s", error_message)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
# NOTE: If we see a falsey value for "continuation" in the response we
# assume we are done and return
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
next_continuation_token = (
response_data.get("continuation")
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
)
raise httpx.HTTPError(error_base) from e
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
)
return chunks, next_continuation_token
response_data = response.json()
total_slices = len(continuation_token_map)
if total_slices < 1:
raise ValueError("continuation_token_map must have at least one entry.")
# We want to guarantee that these invocations are ordered by slice_id,
# because we read in the same order below when parsing parallel_results.
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_all_chunks_paginated_for_slice,
(
index_name,
tenant_state,
slice_id,
total_slices,
continuation_token,
page_size,
),
)
for slice_id, continuation_token in sorted(continuation_token_map.items())
]
return [
chunk["fields"] for chunk in response_data.get("documents", [])
], response_data.get("continuation") or None
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
if len(parallel_results) != total_slices:
raise RuntimeError(
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
)
chunks: list[dict] = []
next_continuation_token_map: dict[int, str | None] = {
key: value for key, value in continuation_token_map.items()
}
for i, parallel_result in enumerate(parallel_results):
if i not in next_continuation_token_map:
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
if parallel_result is None:
logger.error(
f"Failed to get chunks for slice {i} of {total_slices}. "
"The continuation token for this slice will not be updated."
)
continue
chunks.extend(parallel_result[0])
next_continuation_token_map[i] = parallel_result[1]
return chunks, next_continuation_token_map
# TODO(rkuo): candidate for removal if not being used

View File

@@ -56,6 +56,7 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -553,10 +554,9 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
# Avoid over-fetching a very large candidate set for global-phase reranking.
# Keep enough headroom for quality while capping cost on larger indices.
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)
@@ -652,9 +652,9 @@ class VespaDocumentIndex(DocumentIndex):
def get_all_raw_document_chunks_paginated(
self,
continuation_token: str | None,
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict[str, Any]], str | None]:
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
"""Gets all the chunks in Vespa, paginated.
Used in the chunk-level Vespa-to-OpenSearch migration task.
@@ -662,21 +662,21 @@ class VespaDocumentIndex(DocumentIndex):
Args:
continuation_token: Token returned by Vespa representing a page
offset. None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
page_size: Best-effort batch size for the visit.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
raw_chunks, next_continuation_token = get_all_chunks_paginated(
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
index_name=self._index_name,
tenant_state=TenantState(
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
),
continuation_token=continuation_token,
continuation_token_map=continuation_token_map,
page_size=page_size,
)
return raw_chunks, next_continuation_token
return raw_chunks, next_continuation_token_map
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
"""Indexes raw document chunks into Vespa.
@@ -702,3 +702,32 @@ class VespaDocumentIndex(DocumentIndex):
json={"fields": chunk},
)
response.raise_for_status()
def get_chunk_count(self) -> int:
"""Returns the exact number of document chunks in Vespa for this tenant.
Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked`
to get an exact count without fetching any document data.
Includes large chunks. There is no way to filter these out using the
Search API.
"""
where_clause = (
f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true"
)
yql = (
f"select documentid from {self._index_name} "
f"where {where_clause} "
f"limit 0"
)
params: dict[str, str | int] = {
"yql": yql,
"ranking.profile": "unranked",
"timeout": VESPA_TIMEOUT,
}
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
response_data = response.json()
return response_data["root"]["fields"]["totalCount"]

View File

@@ -20,7 +20,20 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -63,6 +76,7 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -59,6 +60,7 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -45,6 +46,7 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -9,6 +11,7 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -51,6 +54,15 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -58,8 +70,18 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -74,6 +96,99 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,21 +64,6 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -159,11 +144,6 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1320,11 +1300,6 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1365,16 +1340,6 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1505,26 +1470,6 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1705,16 +1650,6 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3226,15 +3161,6 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3249,16 +3175,6 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3750,16 +3666,6 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3879,20 +3785,6 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,5 +1,7 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -23,6 +25,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -41,19 +48,40 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
def _load_bundled_recommendations() -> LLMRecommendations:
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
def is_obsolete_model(model_name: str, provider: str) -> bool:
@@ -215,6 +243,23 @@ def model_configurations_for_provider(
) -> list[ModelConfigurationView]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
# Preserve provider-defined ordering while de-duplicating.
model_names: list[str] = []
seen_model_names: set[str] = set()
for model_name in (
fetch_models_for_provider(provider_name) + recommended_visible_models_names
):
if model_name in seen_model_names:
continue
seen_model_names.add(model_name)
model_names.append(model_name)
# Vertex model list can be large and mixed-vendor; alphabetical ordering
# makes model discovery easier in admin selection UIs.
if provider_name == VERTEXAI_PROVIDER_NAME:
model_names = sorted(model_names, key=str.lower)
return [
ModelConfigurationView(
name=model_name,
@@ -222,8 +267,7 @@ def model_configurations_for_provider(
max_input_tokens=get_max_input_tokens(model_name, provider_name),
supports_image_input=model_supports_image_input(model_name, provider_name),
)
for model_name in set(fetch_models_for_provider(provider_name))
| set(recommended_visible_models_names)
for model_name in model_names
]

View File

@@ -52,6 +52,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -63,7 +64,7 @@ from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.build.api.api import nextjs_assets_router
from onyx.server.features.build.api.api import public_build_router
from onyx.server.features.build.api.api import router as build_router
from onyx.server.features.default_assistant.api import (
router as default_assistant_router,
@@ -114,13 +115,16 @@ from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.metrics.postgres_connection_pool import (
setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.pat.api import router as pat_router
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -138,6 +142,7 @@ from onyx.setup import setup_onyx
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_endpoint_context_middleware
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
@@ -266,6 +271,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
# Register pool metrics now that engines are created.
# HTTP instrumentation is set up earlier in get_application() since it
# adds middleware (which Starlette forbids after the app has started).
setup_postgres_connection_pool_metrics(
engines={
"sync": SqlEngine.get_engine(),
"async": get_sqlalchemy_async_engine(),
"readonly": SqlEngine.get_readonly_engine(),
},
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)
@@ -378,8 +394,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -560,12 +576,18 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
add_onyx_request_id_middleware(application, "API", logger)
# Set endpoint context for per-endpoint DB pool attribution metrics.
# Must be registered after all routes are added.
add_endpoint_context_middleware(application)
# HTTP request metrics (latency histograms, in-progress gauge, slow request
# counter). Must be called here — before the app starts — because the
# instrumentator adds middleware via app.add_middleware().
setup_prometheus_metrics(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app with production Prometheus config
setup_prometheus_metrics(application)
use_route_function_names_as_operation_ids(application)
return application

View File

@@ -69,6 +69,12 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -1,6 +1,6 @@
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included, the sections below are for the available tools
TOOL_SECTION_HEADER = "\n\n# Tools\n"
TOOL_SECTION_HEADER = "\n# Tools\n\n"
# This section is included if there are search type tools, currently internal_search and web_search
@@ -16,11 +16,10 @@ When searching for information, if the initial results cannot fully answer the u
Do not repeat the same or very similar queries if it already has been run in the chat history.
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
"""
""".lstrip()
INTERNAL_SEARCH_GUIDANCE = """
## internal_search
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
@@ -28,34 +27,31 @@ Use the `internal_search` tool to search connected applications for information.
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
- Ambiguity: questions about something that is not widely known or understood.
Never provide more than 3 queries at once to `internal_search`.
"""
""".lstrip()
WEB_SEARCH_GUIDANCE = """
## web_search
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
- Accuracy: if the cost of outdated/inaccurate information is high.
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
"""
""".lstrip()
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".rstrip()
""".lstrip()
OPEN_URLS_GUIDANCE = """
## open_url
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
Do not open URLs that are image files like .png, .jpg, etc.
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
"""
""".lstrip()
PYTHON_TOOL_GUIDANCE = """
## python
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
@@ -64,21 +60,21 @@ Use this to give the user a way to download the file OR to display generated ima
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
"""
""".lstrip()
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
"""
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
""".lstrip()
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
"""
""".lstrip()
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.

View File

@@ -1,40 +1,36 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n\n# User Information\n"
USER_INFORMATION_HEADER = "\n# User Information\n\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
"""
""".lstrip()
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
"""
""".lstrip()
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
"""
""".lstrip()
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
"""
""".lstrip()
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
"""
""".lstrip()
# ruff: noqa: E501, W605 end

View File

@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -59,6 +59,9 @@ PUBLIC_ENDPOINT_SPECS = [
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
# craft webapp proxy — access enforced per-session via sharing_scope in handler
("/build/sessions/{session_id}/webapp", {"GET"}),
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
]

View File

@@ -103,6 +103,7 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -987,6 +988,7 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1000,11 +1002,23 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1059,15 +1073,27 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
# Get most recent finished index attempts
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
]
if user and user.role == UserRole.ADMIN:
@@ -1084,8 +1110,10 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
),
)
@@ -1911,6 +1939,7 @@ Tenant ID: {tenant_id}
class BasicCCPairInfo(BaseModel):
has_successful_run: bool
source: DocumentSource
status: ConnectorCredentialPairStatus
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
@@ -1924,13 +1953,17 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
status=cc_pair.status,
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,7 +365,8 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -1,4 +1,5 @@
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
import httpx
@@ -7,16 +8,19 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import ProcessingMode
from onyx.db.enums import SharingScope
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -217,12 +221,15 @@ def get_build_connectors(
return BuildConnectorListResponse(connectors=connectors)
# Headers to skip when proxying (hop-by-hop headers)
# Headers to skip when proxying.
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
EXCLUDED_HEADERS = {
"content-encoding",
"content-length",
"transfer-encoding",
"connection",
"set-cookie",
}
@@ -280,7 +287,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
db_session: Database session
Returns:
The internal URL to proxy requests to
Internal URL to proxy requests to
Raises:
HTTPException: If session not found, port not allocated, or sandbox not found
@@ -294,12 +301,10 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
if session.user_id is None:
raise HTTPException(status_code=404, detail="User not found")
# Get the user's sandbox to get the sandbox_id
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
if sandbox is None:
raise HTTPException(status_code=404, detail="Sandbox not found")
# Use sandbox manager to get the correct internal URL
sandbox_manager = get_sandbox_manager()
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
@@ -365,71 +370,73 @@ def _proxy_request(
raise HTTPException(status_code=502, detail="Bad gateway")
@router.get("/sessions/{session_id}/webapp", response_model=None)
def get_webapp_root(
def _check_webapp_access(
session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
"""Check if user can access a session's webapp.
- public_global: accessible by anyone (no auth required)
- public_org: accessible by any authenticated user
- private: only accessible by the session owner
"""
session = db_session.get(BuildSession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
return session
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
raise HTTPException(status_code=404, detail="Session not found")
return session
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
"""Return a branded Craft HTML page when the sandbox is not reachable.
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
terminal window aesthetic with Minecraft-themed typing animation.
"""
html = _OFFLINE_HTML_PATH.read_text()
return Response(content=html, status_code=503, media_type="text/html")
# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
"/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
session_id: UUID,
request: Request,
_: User = Depends(current_user),
path: str = "",
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy the root path of the webapp for a specific session."""
return _proxy_request("", request, session_id, db_session)
"""Proxy the webapp for a specific session (root and subpaths).
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
def get_webapp_path(
session_id: UUID,
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
return _proxy_request(path, request, session_id, db_session)
# Separate router for Next.js static assets at /_next/*
# This is needed because Next.js apps may reference assets with root-relative paths
# that don't get rewritten. The session_id is extracted from the Referer header.
nextjs_assets_router = APIRouter()
def _extract_session_from_referer(request: Request) -> UUID | None:
"""Extract session_id from the Referer header.
Expects Referer to contain /api/build/sessions/{session_id}/webapp
Accessible without authentication when sharing_scope is public_global.
Returns a friendly offline page when the sandbox is not running.
"""
import re
referer = request.headers.get("referer", "")
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
if match:
try:
return UUID(match.group(1))
except ValueError:
return None
return None
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy Next.js static assets requested at root /_next/ path.
The session_id is extracted from the Referer header since these requests
come from within the iframe context.
"""
session_id = _extract_session_from_referer(request)
if not session_id:
raise HTTPException(
status_code=400,
detail="Could not determine session from request context",
)
return _proxy_request(f"_next/{path}", request, session_id, db_session)
try:
_check_webapp_access(session_id, user, db_session)
except HTTPException as e:
if e.status_code == 401:
return RedirectResponse(url="/auth/login", status_code=302)
raise
try:
return _proxy_request(path, request, session_id, db_session)
except HTTPException as e:
if e.status_code in (502, 503, 504):
return _offline_html_response()
raise
# =============================================================================

View File

@@ -10,6 +10,7 @@ from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
FilesystemEntry as FileSystemEntry,
)
@@ -107,6 +108,7 @@ class SessionResponse(BaseModel):
nextjs_port: int | None
sandbox: SandboxResponse | None
artifacts: list[ArtifactResponse]
sharing_scope: SharingScope
@classmethod
def from_model(
@@ -129,6 +131,7 @@ class SessionResponse(BaseModel):
nextjs_port=session.nextjs_port,
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
sharing_scope=session.sharing_scope,
)
@@ -159,6 +162,19 @@ class SessionListResponse(BaseModel):
sessions: list[SessionResponse]
class SetSessionSharingRequest(BaseModel):
"""Request to set the sharing scope of a session."""
sharing_scope: SharingScope
class SetSessionSharingResponse(BaseModel):
"""Response after setting session sharing scope."""
session_id: str
sharing_scope: SharingScope
# ===== Message Models =====
class MessageRequest(BaseModel):
"""Request to send a message to the CLI agent."""
@@ -244,6 +260,7 @@ class WebappInfo(BaseModel):
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
sharing_scope: SharingScope
# ===== File Upload Models =====

View File

@@ -30,6 +30,8 @@ from onyx.server.features.build.api.models import SessionListResponse
from onyx.server.features.build.api.models import SessionNameGenerateResponse
from onyx.server.features.build.api.models import SessionResponse
from onyx.server.features.build.api.models import SessionUpdateRequest
from onyx.server.features.build.api.models import SetSessionSharingRequest
from onyx.server.features.build.api.models import SetSessionSharingResponse
from onyx.server.features.build.api.models import SuggestionBubble
from onyx.server.features.build.api.models import SuggestionTheme
from onyx.server.features.build.api.models import UploadResponse
@@ -38,6 +40,7 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import allocate_nextjs_port
from onyx.server.features.build.db.build_session import get_build_session
from onyx.server.features.build.db.build_session import set_build_session_sharing_scope
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
@@ -294,6 +297,25 @@ def update_session_name(
return SessionResponse.from_model(session, sandbox)
@router.patch("/{session_id}/public")
def set_session_public(
session_id: UUID,
request: SetSessionSharingRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SetSessionSharingResponse:
"""Set the sharing scope of a build session's webapp."""
updated = set_build_session_sharing_scope(
session_id, user.id, request.sharing_scope, db_session
)
if not updated:
raise HTTPException(status_code=404, detail="Session not found")
return SetSessionSharingResponse(
session_id=str(session_id),
sharing_scope=updated.sharing_scope,
)
@router.delete("/{session_id}", response_model=None)
def delete_session(
session_id: UUID,

View File

@@ -0,0 +1,110 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="refresh" content="15" />
<title>Craft — Starting up</title>
<style>
*,
*::before,
*::after {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas,
monospace;
background: linear-gradient(to bottom right, #030712, #111827, #030712);
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 1.5rem;
padding: 2rem;
}
.terminal {
width: 100%;
max-width: 580px;
border: 2px solid #374151;
border-radius: 2px;
}
.titlebar {
background: #1f2937;
padding: 0.5rem 0.75rem;
display: flex;
align-items: center;
gap: 0.5rem;
border-bottom: 1px solid #374151;
}
.btn {
width: 12px;
height: 12px;
border-radius: 2px;
flex-shrink: 0;
}
.btn-red {
background: #ef4444;
}
.btn-yellow {
background: #eab308;
}
.btn-green {
background: #22c55e;
}
.title-label {
flex: 1;
text-align: center;
font-size: 0.75rem;
color: #6b7280;
margin-right: 36px;
}
.body {
background: #111827;
padding: 1.5rem;
min-height: 200px;
font-size: 0.875rem;
color: #d1d5db;
display: flex;
align-items: flex-start;
gap: 0.375rem;
}
.prompt {
color: #10b981;
user-select: none;
}
.tagline {
font-size: 0.8125rem;
color: #4b5563;
text-align: center;
}
</style>
</head>
<body>
<div class="terminal">
<div class="titlebar">
<div class="btn btn-red"></div>
<div class="btn btn-yellow"></div>
<div class="btn btn-green"></div>
<span class="title-label">crafting_table</span>
</div>
<div class="body">
<span class="prompt">/&gt;</span>
<span>Sandbox is asleep...</span>
</div>
</div>
<p class="tagline">
Ask the owner to open their Craft session to wake it up.
</p>
</body>
</html>

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import MessageType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.db.models import Artifact
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
@@ -159,6 +160,26 @@ def update_session_status(
logger.info(f"Updated build session {session_id} status to {status}")
def set_build_session_sharing_scope(
session_id: UUID,
user_id: UUID,
sharing_scope: SharingScope,
db_session: Session,
) -> BuildSession | None:
"""Set the sharing scope of a build session.
Only the session owner can change this setting.
Returns the updated session, or None if not found/unauthorized.
"""
session = get_build_session(session_id, user_id, db_session)
if not session:
return None
session.sharing_scope = sharing_scope
db_session.commit()
logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}")
return session
def delete_build_session__no_commit(
session_id: UUID,
user_id: UUID,

View File

@@ -474,6 +474,23 @@ class SandboxManager(ABC):
"""
...
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Ensure the Next.js server is running for a session.
Default is a no-op — only meaningful for local backends that manage
process lifecycles directly (e.g., LocalSandboxManager).
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port the Next.js server should be listening on
"""
# Singleton instance cache for the factory
_sandbox_manager_instance: SandboxManager | None = None

View File

@@ -15,6 +15,8 @@ from collections.abc import Generator
from pathlib import Path
from uuid import UUID
import httpx
from onyx.db.enums import SandboxStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.configs import DEMO_DATA_PATH
@@ -35,6 +37,7 @@ from onyx.server.features.build.sandbox.models import LLMProviderConfig
from onyx.server.features.build.sandbox.models import SandboxInfo
from onyx.server.features.build.sandbox.models import SnapshotResult
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import ThreadSafeSet
logger = setup_logger()
@@ -89,9 +92,17 @@ class LocalSandboxManager(SandboxManager):
self._acp_clients: dict[tuple[UUID, UUID], ACPAgentClient] = {}
# Track Next.js processes - keyed by (sandbox_id, session_id) tuple
# Used for clean shutdown when sessions are deleted
# Used for clean shutdown when sessions are deleted.
# Mutated from background threads; all access must hold _nextjs_lock.
self._nextjs_processes: dict[tuple[UUID, UUID], subprocess.Popen[bytes]] = {}
# Track sessions currently being (re)started - prevents concurrent restarts.
# ThreadSafeSet allows atomic check-and-add without holding _nextjs_lock.
self._nextjs_starting: ThreadSafeSet[tuple[UUID, UUID]] = ThreadSafeSet()
# Lock guarding _nextjs_processes (shared across sessions; hold briefly only)
self._nextjs_lock = threading.Lock()
# Validate templates exist (raises RuntimeError if missing)
self._validate_templates()
@@ -326,16 +337,18 @@ class LocalSandboxManager(SandboxManager):
RuntimeError: If termination fails
"""
# Stop all Next.js processes for this sandbox (keyed by (sandbox_id, session_id))
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
with self._nextjs_lock:
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
for key, process in processes_to_stop:
session_id = key[1]
try:
self._stop_nextjs_process(process, session_id)
del self._nextjs_processes[key]
with self._nextjs_lock:
self._nextjs_processes.pop(key, None)
except Exception as e:
logger.warning(
f"Failed to stop Next.js for sandbox {sandbox_id}, "
@@ -516,7 +529,8 @@ class LocalSandboxManager(SandboxManager):
web_dir, nextjs_port
)
# Store process for clean shutdown on session delete
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
with self._nextjs_lock:
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
logger.info("Next.js server started successfully")
# Setup venv and AGENTS.md
@@ -575,7 +589,8 @@ class LocalSandboxManager(SandboxManager):
"""
# Stop Next.js dev server - try stored process first, then fallback to port lookup
process_key = (sandbox_id, session_id)
nextjs_process = self._nextjs_processes.pop(process_key, None)
with self._nextjs_lock:
nextjs_process = self._nextjs_processes.pop(process_key, None)
if nextjs_process is not None:
self._stop_nextjs_process(nextjs_process, session_id)
elif nextjs_port is not None:
@@ -766,6 +781,85 @@ class LocalSandboxManager(SandboxManager):
outputs_path = session_path / "outputs"
return outputs_path.exists()
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Start Next.js server for a session if not already running.
Called when the server is detected as unreachable (e.g., after API server restart).
Returns immediately — the actual startup runs in a background daemon thread.
A per-session guard prevents concurrent restarts from racing.
Lock design: _nextjs_lock is shared across ALL sessions. Holding it during
httpx (1s) or start_nextjs_server (several seconds) would block every other
session's status checks and restarts. We only hold the lock for fast
in-memory ops (dict get, check_and_add). The slow I/O runs in the background
thread without holding any lock.
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port number for the Next.js server
"""
process_key = (sandbox_id, session_id)
with self._nextjs_lock:
existing = self._nextjs_processes.get(process_key)
if existing is not None and existing.poll() is None:
return
# Atomic check-and-add: returns True if already in set (another thread is starting)
if self._nextjs_starting.check_and_add(process_key):
return
def _start_in_background() -> None:
try:
# Port check in background to avoid blocking the main thread
try:
with httpx.Client(timeout=1.0) as client:
client.get(f"http://localhost:{nextjs_port}")
logger.info(
f"Port {nextjs_port} already alive for session {session_id} "
"(orphan process) — skipping restart"
)
return
except Exception:
pass # Port is dead; proceed with restart
logger.info(
f"Starting Next.js for session {session_id} on port {nextjs_port}"
)
sandbox_path = self._get_sandbox_path(sandbox_id)
web_dir = self._directory_manager.get_web_path(
sandbox_path, str(session_id)
)
if not web_dir.exists():
logger.warning(
f"Web dir missing for session {session_id}: {web_dir}"
"cannot restart Next.js"
)
return
process = self._process_manager.start_nextjs_server(
web_dir, nextjs_port
)
with self._nextjs_lock:
self._nextjs_processes[process_key] = process
logger.info(
f"Auto-restarted Next.js for session {session_id} "
f"on port {nextjs_port}"
)
except Exception as e:
logger.error(
f"Failed to auto-restart Next.js for session {session_id}: {e}"
)
finally:
self._nextjs_starting.discard(process_key)
threading.Thread(target=_start_in_background, daemon=True).start()
def restore_snapshot(
self,
sandbox_id: UUID,

View File

@@ -1,10 +0,0 @@
"""Celery tasks for sandbox management."""
from onyx.server.features.build.sandbox.tasks.tasks import (
cleanup_idle_sandboxes_task,
) # noqa: F401
from onyx.server.features.build.sandbox.tasks.tasks import (
sync_sandbox_files,
) # noqa: F401
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]

View File

@@ -1765,6 +1765,7 @@ class SessionManager:
"webapp_url": None,
"status": "no_sandbox",
"ready": False,
"sharing_scope": session.sharing_scope,
}
# Return the proxy URL - the proxy handles routing to the correct sandbox
@@ -1777,11 +1778,21 @@ class SessionManager:
# Quick health check: can the API server reach the NextJS dev server?
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
# If not ready, ask the sandbox manager to ensure Next.js is running.
# For the local backend this triggers a background restart so that the
# frontend poll loop eventually sees ready=True without the user having
# to manually recreate the session.
if not ready:
self._sandbox_manager.ensure_nextjs_running(
sandbox.id, session_id, session.nextjs_port
)
return {
"has_webapp": session.nextjs_port is not None,
"webapp_url": webapp_url,
"status": sandbox.status.value,
"ready": ready,
"sharing_scope": session.sharing_scope,
}
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:

View File

@@ -111,7 +111,8 @@ class DocumentSet(BaseModel):
id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector
cc_pair.connector,
credential_ids=[cc_pair.credential_id],
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential

View File

@@ -26,13 +26,17 @@ def get_opensearch_migration_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> OpenSearchMigrationStatusResponse:
total_chunks_migrated, created_at, migration_completed_at = (
get_opensearch_migration_state(db_session)
)
(
total_chunks_migrated,
created_at,
migration_completed_at,
approx_chunk_count_in_vespa,
) = get_opensearch_migration_state(db_session)
return OpenSearchMigrationStatusResponse(
total_chunks_migrated=total_chunks_migrated,
created_at=created_at,
migration_completed_at=migration_completed_at,
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)

Some files were not shown because too many files have changed in this diff Show More