Compare commits

..

1 Commits

Author SHA1 Message Date
Nik
48e7428069 chore(helm): remove broken code-interpreter dependency
The code-interpreter Helm chart repo at
https://onyx-dot-app.github.io/code-interpreter/ returns 404,
causing ct lint to fail in CI. Remove it from Chart.yaml
dependencies, Chart.lock, ct.yaml chart-repos, and the CI
workflow's helm repo add step.
2026-02-19 20:17:14 -08:00
390 changed files with 6818 additions and 14589 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -1,161 +0,0 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,9 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -126,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,6 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Install Redis operator

View File

@@ -20,7 +20,6 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -424,7 +423,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -445,7 +443,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -704,7 +701,6 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
To run them:
@@ -616,9 +616,3 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -21,14 +21,15 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import NamedTuple
from typing import List, NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -104,6 +105,56 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -1,29 +0,0 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,48 +0,0 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -1,31 +0,0 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -127,14 +127,9 @@ class ScimDAL(DAL):
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
)
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
self._session.add(mapping)
self._session.flush()
return mapping
@@ -253,11 +248,11 @@ class ScimDAL(DAL):
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
) -> tuple[list[tuple[User, str | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, mapping) pairs, total_count).
A tuple of (list of (user, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
@@ -297,104 +292,33 @@ class ScimDAL(DAL):
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
)
.unique()
.all()
).all()
)
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
) -> None:
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
)
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
return {m.user_id: m.external_id for m in mappings}
# ------------------------------------------------------------------
# Group mapping operations
@@ -559,13 +483,9 @@ class ScimDAL(DAL):
if not user_ids:
return []
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users_by_id = {u.id: u for u in users}
return [
@@ -584,13 +504,9 @@ class ScimDAL(DAL):
"""
if not uuids:
return []
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30
num_hits: int = 50
include_content: bool = False
stream: bool = False

View File

@@ -26,10 +26,12 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
@@ -39,8 +41,6 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -53,6 +53,7 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
@@ -62,18 +63,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@@ -111,6 +100,28 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -144,10 +155,9 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or name.formatted
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# ---------------------------------------------------------------------------
@@ -161,7 +171,6 @@ def list_users(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
@@ -174,19 +183,12 @@ def list_users(
return _scim_error_response(400, str(e))
try:
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
)
for user, mapping in users_with_mappings
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
]
return ScimListResponse(
@@ -201,7 +203,6 @@ def list_users(
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
@@ -214,26 +215,20 @@ def get_user(
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
)
return _user_to_scim(user, mapping.external_id if mapping else None)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip()
email = user_resource.userName.strip().lower()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
@@ -269,14 +264,11 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
dal.create_user_mapping(
external_id=external_id, user_id=user.id, scim_username=scim_username
)
dal.create_user_mapping(external_id=external_id, user_id=user.id)
dal.commit()
return provider.build_user_resource(user, external_id, scim_username=scim_username)
return _user_to_scim(user, external_id)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -284,7 +276,6 @@ def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
@@ -302,27 +293,19 @@ def replace_user(
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip(),
email=user_resource.userName.strip().lower(),
is_active=user_resource.active,
personal_name=personal_name,
personal_name=_scim_name_to_str(user_resource.name),
)
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
dal.sync_user_external_id(user.id, new_external_id)
dal.commit()
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
return _user_to_scim(user, new_external_id)
@scim_router.patch("/Users/{user_id}", response_model=None)
@@ -330,7 +313,6 @@ def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
@@ -348,19 +330,11 @@ def patch_user(
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
)
current = _user_to_scim(user, external_id)
try:
patched = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
patched = apply_user_patch(patch_request.Operations, current)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -371,40 +345,22 @@ def patch_user(
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
patched.userName.strip().lower()
if patched.userName.lower() != user.email
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=personal_name,
personal_name=_scim_name_to_str(patched.name),
)
dal.sync_user_external_id(
user.id, patched.externalId, scim_username=new_scim_username
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.commit()
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
return _user_to_scim(user, patched.externalId)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -442,6 +398,24 @@ def delete_user(
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
@@ -500,7 +474,6 @@ def list_groups(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
@@ -518,7 +491,7 @@ def list_groups(
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
@@ -534,7 +507,6 @@ def list_groups(
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
@@ -549,16 +521,13 @@ def get_group(
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
@@ -596,7 +565,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return provider.build_group_resource(db_group, members, external_id)
return _group_to_scim(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -604,7 +573,6 @@ def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
@@ -627,7 +595,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, group_resource.externalId)
return _group_to_scim(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -635,7 +603,6 @@ def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
@@ -654,11 +621,11 @@ def patch_group(
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = provider.build_group_resource(group, current_members, external_id)
current = _group_to_scim(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current, provider.ignored_patch_paths
patch_request.Operations, current
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -685,7 +652,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, patched.externalId)
return _group_to_scim(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)

View File

@@ -63,13 +63,6 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -83,10 +76,8 @@ class ScimUserResource(BaseModel):
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
@@ -130,40 +121,12 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: ScimPatchValue = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):

View File

@@ -16,12 +16,9 @@ from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
@@ -44,15 +41,9 @@ _MEMBER_FILTER_RE = re.compile(
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
@@ -64,9 +55,9 @@ def apply_user_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data, ignored_paths)
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data, ignored_paths)
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
@@ -80,34 +71,30 @@ def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data, ignored_paths)
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: ScimPatchValue,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on user data by SCIM path."""
if path in ignored_paths:
return
elif path == "active":
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
@@ -120,7 +107,7 @@ def _set_user_field(
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
data["displayName"] = value
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
@@ -129,15 +116,9 @@ def _set_user_field(
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -152,9 +133,7 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -175,48 +154,38 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data, ignored_paths)
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
_set_group_field(path, op.value, data)
def _replace_members(
value: list[dict],
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -228,14 +197,11 @@ def _replace_members(
def _set_group_field(
path: str,
value: ScimPatchValue,
value: str | bool | dict | list | None,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
elif path == "displayname":
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
@@ -257,10 +223,8 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in member_dicts:
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -1,123 +0,0 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
"""
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=username,
name=self._build_scim_name(user),
displayName=user.personal_name,
emails=[ScimEmail(value=username, type="work", primary=True)],
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -1,25 +0,0 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return frozenset({"id", "schemas", "meta"})

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -277,32 +277,13 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
domain = email.split("@")[-1].lower()
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
)
# Check domain whitelist if configured
@@ -1690,10 +1671,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -22,7 +22,6 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -59,7 +58,6 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
@@ -278,7 +276,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -328,7 +325,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -5,17 +5,14 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -24,16 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -64,73 +57,14 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -186,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -218,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -245,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -281,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -425,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -684,8 +557,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -697,16 +570,7 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -722,23 +586,19 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
):
skipped_guard += 1
continue
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -757,8 +617,6 @@ def process_single_user_file_project_sync(
)
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,

View File

@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -158,7 +156,36 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_batch.append(sanitize_document_for_postgres(doc))
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -575,13 +602,10 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch_cleaned,
nodes=hierarchy_node_batch,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -600,7 +624,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -30,7 +30,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
@@ -657,12 +656,7 @@ def run_llm_loop(
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
# Fetch this in a short-lived session so the long-running stream loop does
# not pin a connection just to keep read state alive.
with get_session_with_current_tenant() as prompt_db_session:
default_base_system_prompt: str = get_default_base_system_prompt(
prompt_db_session
)
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
system_prompt = None
custom_agent_prompt_msg = None

View File

@@ -856,11 +856,6 @@ def handle_stream_message_objects(
reserved_tokens=reserved_token_count,
)
# Release any read transaction before entering the long-running LLM stream.
# Without this, the request-scoped session can keep a connection checked out
# for the full stream duration.
db_session.commit()
# The stream generator can resume on a different worker thread after early yields.
# Set this right before launching the LLM loop so run_in_background copies the right context.
if new_msg_req.mock_llm_response is not None:

View File

@@ -190,7 +190,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "\n".join(sections)
return USER_INFORMATION_HEADER + "".join(sections)
def build_system_prompt(
@@ -228,21 +228,23 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -252,14 +254,12 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
system_prompt += INTERNAL_SEARCH_GUIDANCE
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -269,23 +269,20 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
)
if has_open_urls or include_all_guidance:
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
system_prompt += OPEN_URLS_GUIDANCE
if has_python or include_all_guidance:
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
system_prompt += PYTHON_TOOL_GUIDANCE
if has_generate_image or include_all_guidance:
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER") or ""
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -284,9 +282,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
@@ -642,14 +637,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,25 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
# How long a queued user-file-project-sync task remains valid.
# Should be short enough to discard stale queue entries under load while still
# allowing workers enough time to pick up new tasks.
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before user-file-project-sync producers stop enqueuing.
# This applies backpressure when workers are falling behind.
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -462,12 +443,8 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"

View File

@@ -16,22 +16,6 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
def _is_rate_limit_error(error: HttpError) -> bool:
"""Google sometimes returns rate-limit errors as 403 with reason
'userRateLimitExceeded' instead of 429. This helper detects both."""
if error.resp.status == 429:
return True
if error.resp.status != 403:
return False
error_details = getattr(error, "error_details", None) or []
for detail in error_details:
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
return True
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
@@ -73,7 +57,7 @@ def _execute_with_retry(request: Any) -> Any:
except HttpError as error:
attempt += 1
if _is_rate_limit_error(error):
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
@@ -156,16 +140,16 @@ def _execute_single_retrieval(
)
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e

View File

@@ -1,96 +0,0 @@
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
The office365 library's GraphClient requires an ``AzureEnvironment`` string
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
cloud. Our connectors instead expose free-text ``authority_host`` and
``graph_api_host`` fields so the frontend doesn't need to know about SDK
internals.
This module bridges the gap: given the two host URLs the user configured, it
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
"""
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
from pydantic import BaseModel
from onyx.connectors.exceptions import ConnectorValidationError
class MicrosoftGraphEnvironment(BaseModel):
"""One row of the inverse mapping."""
environment: str
graph_host: str
authority_host: str
sharepoint_domain_suffix: str
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Global,
graph_host="https://graph.microsoft.com",
authority_host="https://login.microsoftonline.com",
sharepoint_domain_suffix="sharepoint.com",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentHigh,
graph_host="https://graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentDoD,
graph_host="https://dod-graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.China,
graph_host="https://microsoftgraph.chinacloudapi.cn",
authority_host="https://login.chinacloudapi.cn",
sharepoint_domain_suffix="sharepoint.cn",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Germany,
graph_host="https://graph.microsoft.de",
authority_host="https://login.microsoftonline.de",
sharepoint_domain_suffix="sharepoint.de",
),
]
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
env.graph_host: env for env in _ENVIRONMENTS
}
def resolve_microsoft_environment(
graph_api_host: str,
authority_host: str,
) -> MicrosoftGraphEnvironment:
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
Raises ``ConnectorValidationError`` when the combination is unknown or
internally inconsistent (e.g. a GCC-High graph host paired with a
commercial authority host).
"""
graph_api_host = graph_api_host.rstrip("/")
authority_host = authority_host.rstrip("/")
env = _GRAPH_HOST_INDEX.get(graph_api_host)
if env is None:
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
raise ConnectorValidationError(
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
f"Recognised hosts: {known}"
)
if env.authority_host != authority_host:
raise ConnectorValidationError(
f"Authority host '{authority_host}' is inconsistent with "
f"graph API host '{graph_api_host}'. "
f"Expected authority host '{env.authority_host}' "
f"for the {env.environment} environment."
)
return env

View File

@@ -6,7 +6,6 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
return {
key: [str(item) for item in val] if isinstance(val, list) else str(val)
for key, val in v.items()
}
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None

View File

@@ -47,7 +47,6 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -84,11 +83,7 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
GRAPH_API_BASE = f"{DEFAULT_GRAPH_API_HOST}/v1.0"
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
GRAPH_API_MAX_RETRIES = 5
GRAPH_API_RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})
@@ -147,9 +142,7 @@ class DriveItemData(BaseModel):
self.id,
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
)
item = DriveItem(graph_client, path)
item.set_property("id", self.id)
return item
return DriveItem(graph_client, path)
# The office365 library's ClientContext caches the access token from its
@@ -183,25 +176,6 @@ class CertificateData(BaseModel):
thumbprint: str
def _site_page_in_time_window(
page: dict[str, Any],
start: datetime | None,
end: datetime | None,
) -> bool:
"""Return True if the page's lastModifiedDateTime falls within [start, end]."""
if start is None and end is None:
return True
raw = page.get("lastModifiedDateTime")
if not raw:
return True
if not isinstance(raw, str):
raise ValueError(f"lastModifiedDateTime is not a string: {raw}")
last_modified = datetime.fromisoformat(raw.replace("Z", "+00:00"))
return (start is None or last_modified >= start) and (
end is None or last_modified <= end
)
def sleep_and_retry(
query_obj: ClientQuery, method_name: str, max_retries: int = 3
) -> Any:
@@ -247,12 +221,6 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
current_drive_name: str | None = None
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
current_drive_web_url: str | None = None
# Resolved drive ID — avoids re-resolving on checkpoint resume
current_drive_id: str | None = None
# Next delta API page URL for per-page checkpointing within a drive.
# When set, Phase 3b fetches one page at a time so progress is persisted
# between pages. None means BFS path or no active delta traversal.
current_drive_delta_next_link: str | None = None
process_site_pages: bool = False
@@ -298,12 +266,10 @@ def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData
def acquire_token_for_rest(
msal_app: msal.ConfidentialClientApplication,
sp_tenant_domain: str,
sharepoint_domain_suffix: str,
msal_app: msal.ConfidentialClientApplication, sp_tenant_domain: str
) -> TokenResponse:
token = msal_app.acquire_token_for_client(
scopes=[f"https://{sp_tenant_domain}.{sharepoint_domain_suffix}/.default"]
scopes=[f"https://{sp_tenant_domain}.sharepoint.com/.default"]
)
return TokenResponse.from_json(token)
@@ -418,13 +384,12 @@ def _download_via_graph_api(
drive_id: str,
item_id: str,
bytes_allowed: int,
graph_api_base: str,
) -> bytes:
"""Download a drive item via the Graph API /content endpoint with a byte cap.
Raises SizeCapExceeded if the cap is exceeded.
"""
url = f"{graph_api_base}/drives/{drive_id}/items/{item_id}/content"
url = f"{GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}/content"
headers = {"Authorization": f"Bearer {access_token}"}
with requests.get(
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
@@ -445,7 +410,6 @@ def _convert_driveitem_to_document_with_permissions(
drive_name: str,
ctx: ClientContext | None,
graph_client: GraphClient,
graph_api_base: str,
include_permissions: bool = False,
parent_hierarchy_raw_node_id: str | None = None,
access_token: str | None = None,
@@ -502,7 +466,6 @@ def _convert_driveitem_to_document_with_permissions(
driveitem.drive_id,
driveitem.id,
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
graph_api_base=graph_api_base,
)
except SizeCapExceeded:
logger.warning(
@@ -822,9 +785,6 @@ class SharepointConnector(
sites: list[str] = [],
include_site_pages: bool = True,
include_site_documents: bool = True,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
) -> None:
self.batch_size = batch_size
self.sites = list(sites)
@@ -841,20 +801,6 @@ class SharepointConnector(
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
self.graph_api_base = f"{self.graph_api_host}/v1.0"
self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
logger.warning(
f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
f"for the {resolved_env.environment} environment. "
f"Using '{resolved_env.sharepoint_domain_suffix}'."
)
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
if not self.include_site_documents and not self.include_site_pages:
@@ -910,9 +856,8 @@ class SharepointConnector(
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
sp_domain_suffix = self.sharepoint_domain_suffix
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
self._cached_rest_ctx_url = site_url
self._cached_rest_ctx_created_at = time.monotonic()
@@ -1172,36 +1117,76 @@ class SharepointConnector(
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Yield SharePoint site pages (.aspx files) one at a time.
) -> list[dict[str, Any]]:
"""Fetch SharePoint site pages (.aspx files) using the SharePoint Pages API."""
Pages are fetched via the Graph Pages API and yielded lazily as each
API page arrives, so memory stays bounded regardless of total page count.
Time-window filtering is applied per-item before yielding.
"""
# Get the site to extract the site ID
site = self.graph_client.sites.get_by_url(site_descriptor.url)
site.execute_query()
site.execute_query() # Execute the query to actually fetch the data
site_id = site.id
page_url: str | None = (
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
# Get the token acquisition function from the GraphClient
token_data = self._acquire_token()
access_token = token_data.get("access_token")
if not access_token:
raise RuntimeError("Failed to acquire access token")
# Construct the SharePoint Pages API endpoint
# Using API directly, since the Graph Client doesn't support the Pages API
pages_endpoint = f"https://graph.microsoft.com/v1.0/sites/{site_id}/pages/microsoft.graph.sitePage"
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
# Add expand parameter to get canvas layout content
params = {"$expand": "canvasLayout"}
response = requests.get(
pages_endpoint,
headers=headers,
params=params,
timeout=REQUEST_TIMEOUT_SECONDS,
)
params: dict[str, str] | None = {"$expand": "canvasLayout"}
total_yielded = 0
response.raise_for_status()
pages_data = response.json()
all_pages = pages_data.get("value", [])
while page_url:
data = self._graph_api_get_json(page_url, params)
params = None # nextLink already embeds query params
# Handle pagination if there are more pages
# TODO: This accumulates all pages in memory and can be heavy on large tenants.
# We should process each page incrementally to avoid unbounded growth.
while "@odata.nextLink" in pages_data:
next_url = pages_data["@odata.nextLink"]
response = requests.get(
next_url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS
)
response.raise_for_status()
pages_data = response.json()
all_pages.extend(pages_data.get("value", []))
for page in data.get("value", []):
if not _site_page_in_time_window(page, start, end):
continue
total_yielded += 1
yield page
logger.debug(f"Found {len(all_pages)} site pages in {site_descriptor.url}")
page_url = data.get("@odata.nextLink")
# Filter pages based on time window if specified
if start is not None or end is not None:
filtered_pages: list[dict[str, Any]] = []
for page in all_pages:
page_modified = page.get("lastModifiedDateTime")
if page_modified:
if isinstance(page_modified, str):
page_modified = datetime.fromisoformat(
page_modified.replace("Z", "+00:00")
)
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
if start is not None and page_modified < start:
continue
if end is not None and page_modified > end:
continue
filtered_pages.append(page)
all_pages = filtered_pages
return all_pages
def _acquire_token(self) -> dict[str, Any]:
"""
@@ -1211,7 +1196,7 @@ class SharepointConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
return token
@@ -1284,10 +1269,9 @@ class SharepointConnector(
Performs BFS folder traversal manually, fetching one page of children
at a time so that memory usage stays bounded regardless of drive size.
"""
base = f"{self.graph_api_base}/drives/{drive_id}"
base = f"{GRAPH_API_BASE}/drives/{drive_id}"
if folder_path:
encoded_path = quote(folder_path, safe="/")
start_url = f"{base}/root:/{encoded_path}:/children"
start_url = f"{base}/root:/{folder_path}:/children"
else:
start_url = f"{base}/root/children"
@@ -1345,7 +1329,7 @@ class SharepointConnector(
"""
use_timestamp_token = start is not None and start > _EPOCH
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
initial_url = f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta"
if use_timestamp_token:
assert start is not None # mypy
token = quote(start.isoformat(timespec="seconds"))
@@ -1391,7 +1375,7 @@ class SharepointConnector(
drive_id,
)
yield from self._iter_delta_pages(
initial_url=f"{self.graph_api_base}/drives/{drive_id}/root/delta",
initial_url=f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta",
drive_id=drive_id,
start=start,
end=end,
@@ -1422,87 +1406,6 @@ class SharepointConnector(
if not page_url:
break
def _build_delta_start_url(
self,
drive_id: str,
start: datetime | None = None,
page_size: int = 200,
) -> str:
"""Build the initial delta API URL with query parameters embedded.
Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
URL so that the returned string is fully self-contained and can be
stored in a checkpoint without needing a separate params dict.
"""
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
params = [f"$top={page_size}"]
if start is not None and start > _EPOCH:
token = quote(start.isoformat(timespec="seconds"))
params.append(f"token={token}")
return f"{base_url}?{'&'.join(params)}"
def _fetch_one_delta_page(
self,
page_url: str,
drive_id: str,
start: datetime | None = None,
end: datetime | None = None,
page_size: int = 200,
) -> tuple[list[DriveItemData], str | None]:
"""Fetch a single page of delta API results.
Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
the delta enumeration is complete (deltaLink with no nextLink).
On 410 Gone (expired token) returns ``([], full_resync_url)`` so
the caller can store the resync URL in the checkpoint and retry on
the next cycle.
"""
try:
data = self._graph_api_get_json(page_url)
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 410:
logger.warning(
"Delta token expired (410 Gone) for drive '%s'. "
"Will restart with full delta enumeration.",
drive_id,
)
full_url = (
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
f"?$top={page_size}"
)
return [], full_url
raise
items: list[DriveItemData] = []
for item in data.get("value", []):
if "folder" in item or "deleted" in item:
continue
if start is not None or end is not None:
raw_ts = item.get("lastModifiedDateTime")
if raw_ts:
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
if start is not None and mod_dt < start:
continue
if end is not None and mod_dt > end:
continue
items.append(DriveItemData.from_graph_json(item))
next_url = data.get("@odata.nextLink")
if next_url:
return items, next_url
return items, None
@staticmethod
def _clear_drive_checkpoint_state(
checkpoint: "SharepointConnectorCheckpoint",
) -> None:
"""Reset all drive-level fields in the checkpoint."""
checkpoint.current_drive_name = None
checkpoint.current_drive_id = None
checkpoint.current_drive_web_url = None
checkpoint.current_drive_delta_next_link = None
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()
@@ -1589,7 +1492,7 @@ class SharepointConnector(
sp_private_key = credentials.get("sp_private_key")
sp_certificate_password = credentials.get("sp_certificate_password")
authority_url = f"{self.authority_host}/{sp_directory_id}"
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
logger.info("Using certificate authentication")
@@ -1605,7 +1508,6 @@ class SharepointConnector(
if certificate_data is None:
raise RuntimeError("Failed to load certificate")
logger.info(f"Creating MSAL app with authority url {authority_url}")
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
@@ -1631,15 +1533,13 @@ class SharepointConnector(
raise ConnectorValidationError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if token is None:
raise ConnectorValidationError("Failed to acquire token for graph")
return token
self._graph_client = GraphClient(
_acquire_token_for_graph, environment=self._azure_environment
)
self._graph_client = GraphClient(_acquire_token_for_graph)
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
org = self.graph_client.organization.get().execute_query()
if not org or len(org) == 0:
@@ -1947,13 +1847,14 @@ class SharepointConnector(
# Return checkpoint to allow persistence after drive initialization
return checkpoint
# Phase 3a: Initialize the next drive for processing
# Phase 3: Process documents from current drive
if (
checkpoint.current_site_descriptor
and checkpoint.cached_drive_names
and len(checkpoint.cached_drive_names) > 0
and checkpoint.current_drive_name is None
):
checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
@@ -1961,8 +1862,7 @@ class SharepointConnector(
site_descriptor = checkpoint.current_site_descriptor
logger.info(
f"Processing drive '{checkpoint.current_drive_name}' "
f"in site: {site_descriptor.url}"
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
)
logger.debug(f"Time range: {start_dt} to {end_dt}")
@@ -1971,35 +1871,35 @@ class SharepointConnector(
logger.warning("Current drive name is None, skipping")
return checkpoint
driveitems: Iterable[DriveItemData] = iter(())
drive_web_url: str | None = None
try:
logger.info(
f"Fetching drive items for drive name: {current_drive_name}"
)
result = self._resolve_drive(site_descriptor, current_drive_name)
if result is None:
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
drive_id, drive_web_url = result
checkpoint.current_drive_id = drive_id
checkpoint.current_drive_web_url = drive_web_url
if result is not None:
drive_id, drive_web_url = result
driveitems = self._get_drive_items_for_drive_id(
site_descriptor, drive_id, start_dt, end_dt
)
checkpoint.current_drive_web_url = drive_web_url
except Exception as e:
logger.error(
f"Failed to retrieve items from drive '{current_drive_name}' "
f"in site: {site_descriptor.url}: {e}"
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to access drive '{current_drive_name}' "
f"in site '{site_descriptor.url}': {str(e)}",
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
return checkpoint
display_drive_name = SHARED_DOCUMENTS_MAP.get(
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
current_drive_name = SHARED_DOCUMENTS_MAP.get(
current_drive_name, current_drive_name
)
@@ -2007,74 +1907,10 @@ class SharepointConnector(
yield from self._yield_drive_hierarchy_node(
site_descriptor.url,
drive_web_url,
display_drive_name,
current_drive_name,
checkpoint,
)
# For non-folder-scoped drives, use delta API with per-page
# checkpointing. Build the initial URL and fall through to 3b.
if not site_descriptor.folder_path:
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
drive_id, start_dt
)
# else: BFS path — delta_next_link stays None;
# Phase 3b will use _iter_drive_items_paged.
# Phase 3b: Process items from the current drive
if (
checkpoint.current_site_descriptor
and checkpoint.current_drive_name is not None
and checkpoint.current_drive_id is not None
):
site_descriptor = checkpoint.current_site_descriptor
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
current_drive_name = SHARED_DOCUMENTS_MAP.get(
checkpoint.current_drive_name, checkpoint.current_drive_name
)
drive_web_url = checkpoint.current_drive_web_url
# --- determine item source ---
driveitems: Iterable[DriveItemData]
has_more_delta_pages = False
if checkpoint.current_drive_delta_next_link:
# Delta path: fetch one page at a time for checkpointing
try:
page_items, next_url = self._fetch_one_delta_page(
page_url=checkpoint.current_drive_delta_next_link,
drive_id=checkpoint.current_drive_id,
start=start_dt,
end=end_dt,
)
except Exception as e:
logger.error(
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
driveitems = page_items
has_more_delta_pages = next_url is not None
if next_url:
checkpoint.current_drive_delta_next_link = next_url
else:
# BFS path (folder-scoped): process all items at once
driveitems = self._iter_drive_items_paged(
drive_id=checkpoint.current_drive_id,
folder_path=site_descriptor.folder_path,
start=start_dt,
end=end_dt,
)
item_count = 0
for driveitem in driveitems:
item_count += 1
@@ -2116,6 +1952,8 @@ class SharepointConnector(
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)
# Re-acquire token in case it expired during a long traversal
# MSAL has a cache that returns the same token while still valid.
access_token = self._get_graph_access_token()
doc_or_failure = _convert_driveitem_to_document_with_permissions(
driveitem,
@@ -2124,7 +1962,6 @@ class SharepointConnector(
self.graph_client,
include_permissions=include_permissions,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
graph_api_base=self.graph_api_base,
access_token=access_token,
)
@@ -2151,11 +1988,8 @@ class SharepointConnector(
)
logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")
if has_more_delta_pages:
return checkpoint
self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
# Phase 4: Progression logic - determine next step
# If we have more drives in current site, continue with current site

View File

@@ -11,7 +11,6 @@ from dateutil import parser
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -259,21 +258,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch
def validate_connector_settings(self) -> None:
"""
Very basic validation, we could do more here
"""
if not self.base_url.startswith("https://") and not self.base_url.startswith(
"http://"
):
raise ConnectorValidationError(
"Base URL must start with https:// or http://"
)
try:
get_all_post_ids(self.slab_bot_token)
except ConnectorMissingCredentialError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")

View File

@@ -23,7 +23,6 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -51,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -67,19 +63,12 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
# impls for BaseConnector
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -87,7 +76,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{self.authority_host}/{teams_directory_id}"
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -102,7 +91,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if not isinstance(token, dict):
@@ -110,9 +99,7 @@ class TeamsConnector(
return token
self.graph_client = GraphClient(
_acquire_token_func, environment=self._azure_environment
)
self.graph_client = GraphClient(_acquire_token_func)
return None
def validate_connector_settings(self) -> None:

View File

@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session | None = None,
db_session: Session,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -928,7 +925,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session (optional if search_settings provided)
db_session: Database session
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1156,10 +1153,7 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -43,7 +41,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -51,19 +49,18 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
base_filters = user_provided_filters or BaseFilters()
document_set_filter = (
base_filters.document_set
if base_filters.document_set is not None
else persona_document_sets
)
if (
user_provided_filters
and user_provided_filters.document_set is None
and persona_document_sets is not None
):
base_filters.document_set = persona_document_sets
time_filter = base_filters.time_cutoff or persona_time_cutoff
source_filter = base_filters.source_type
@@ -106,20 +103,15 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
project_id=project_id,
source_type=source_filter,
document_set=document_set_filter,
document_set=persona_document_sets,
time_cutoff=time_filter,
tags=base_filters.tags,
access_control_list=user_acl_filters,
@@ -260,15 +252,11 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -309,7 +297,6 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -328,8 +315,6 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -52,14 +50,9 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
db_session: Session,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
query_embedding = get_query_embedding(query_request.query, db_session)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -85,9 +78,7 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
db_session: Session,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -97,22 +88,14 @@ def search_chunks(
else None
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -131,10 +114,7 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
(_embed_and_search, (query_request, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -1,21 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
db_session: Session,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
return server
def update_code_interpreter_server_enabled(
db_session: Session,
enabled: bool,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
server.server_enabled = enabled
db_session.commit()
return server

View File

@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
.options(selectinload(DocumentSetDBModel.federated_connectors))
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_documents_for_document_set_paginated(

View File

@@ -1,102 +1,11 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_schemas_needing_migration(
tenant_schemas: list[str], head_rev: str
) -> list[str]:
"""Return only schemas whose current alembic version is not at head.
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
into a temp table one at a time. This avoids building a massive UNION ALL
query (which locks the DB and times out at 17k+ schemas) and instead
acquires locks sequentially, one schema per iteration.
"""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Populate a temp input table with exactly the schemas we care about.
# The DO block reads from this table so it only iterates the requested
# schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
conn.execute(
text(
"INSERT INTO _tenant_schemas_input (schema_name) "
"SELECT unnest(CAST(:schemas AS text[]))"
),
{"schemas": tenant_schemas},
)
conn.execute(
text(
"CREATE TEMP TABLE _alembic_version_snapshot "
"(schema_name text, version_num text)"
)
)
conn.execute(
text(
"""
DO $$
DECLARE
s text;
schemas text[];
BEGIN
SELECT array_agg(schema_name) INTO schemas
FROM _tenant_schemas_input;
IF schemas IS NULL THEN
RAISE NOTICE 'No tenant schemas found.';
RETURN;
END IF;
FOREACH s IN ARRAY schemas LOOP
BEGIN
EXECUTE format(
'INSERT INTO _alembic_version_snapshot
SELECT %L, version_num FROM %I.alembic_version',
s, s
);
EXCEPTION
-- undefined_table: schema exists but has no alembic_version
-- table yet (new tenant, not yet migrated).
-- invalid_schema_name: tenant is registered but its
-- PostgreSQL schema does not exist yet (e.g. provisioning
-- incomplete). Both cases mean no version is available and
-- the schema will be included in the migration list.
WHEN undefined_table THEN NULL;
WHEN invalid_schema_name THEN NULL;
END;
END LOOP;
END;
$$
"""
)
)
rows = conn.execute(
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
)
version_by_schema = {row[0]: row[1] for row in rows}
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
# Schemas missing from the snapshot have no alembic_version table yet and
# also need migration. version_by_schema.get(s) returns None for those,
# and None != head_rev, so they are included automatically.
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

View File

@@ -287,7 +287,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
"Credential", back_populates="user", lazy="joined"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,6 +321,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -4939,12 +4940,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
department: Mapped[str | None] = mapped_column(String, nullable=True)
manager: Mapped[str | None] = mapped_column(String, nullable=True)
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4983,12 +4978,3 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
.limit(1)
)
if not user:
row = result.first()
if not row:
return None
_schedule_pat_last_used_update(hashed_token, now)
return user
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
asyncio.create_task(_update_last_used())
return user
def create_pat(

View File

@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.user),
)
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

View File

@@ -2,7 +2,6 @@ import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from uuid import UUID
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -14,26 +13,18 @@ from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(
num_sessions: int,
num_messages: int,
days: int,
user_id: UUID | None = None,
persona_id: int | None = None,
) -> None:
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
user_id: optional user to associate with sessions
persona_id: optional persona/assistant to associate with sessions
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")

View File

@@ -121,7 +121,6 @@ class VespaDocumentUserFields:
"""
user_projects: list[int] | None = None
personas: list[int] | None = None
@dataclass

View File

@@ -148,7 +148,6 @@ class MetadataUpdateRequest(BaseModel):
hidden: bool | None = None
secondary_index_updated: bool | None = None
project_ids: set[int] | None = None
persona_ids: set[int] | None = None
class IndexRetrievalFilters(BaseModel):

View File

@@ -50,7 +50,6 @@ from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
@@ -216,7 +215,6 @@ def _convert_onyx_chunk_to_opensearch_document(
# OpenSearch and it will not store any data at all for this field, which
# is different from supplying an empty list.
user_projects=chunk.user_project or None,
personas=chunk.personas or None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
@@ -364,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
if user_fields and user_fields.user_projects
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
else None
),
)
try:
@@ -716,10 +709,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
update_request.project_ids
)
if update_request.persona_ids is not None:
properties_to_update[PERSONAS_FIELD_NAME] = list(
update_request.persona_ids
)
if not properties_to_update:
if len(update_request.document_ids) > 1:

View File

@@ -41,7 +41,6 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
USER_PROJECTS_FIELD_NAME = "user_projects"
PERSONAS_FIELD_NAME = "personas"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -157,7 +156,6 @@ class DocumentChunk(BaseModel):
document_sets: list[str] | None = None
user_projects: list[int] | None = None
personas: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@@ -487,7 +485,6 @@ class DocumentSchema:
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
PERSONAS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.

View File

@@ -181,11 +181,6 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field personas type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -689,9 +689,6 @@ class VespaIndex(DocumentIndex):
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(
document_ids=[doc_id],
doc_id_to_chunk_cnt={
@@ -702,7 +699,6 @@ class VespaIndex(DocumentIndex):
boost=fields.boost if fields is not None else None,
hidden=fields.hidden if fields is not None else None,
project_ids=project_ids,
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])

View File

@@ -46,7 +46,6 @@ from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -219,7 +218,6 @@ def _index_vespa_chunk(
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
PERSONAS: chunk.personas if chunk.personas is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -183,10 +183,6 @@ def _update_single_chunk(
model_config = {"frozen": True}
assign: list[int]
class _Personas(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based the Vespa schema. Changes to the
@@ -197,7 +193,6 @@ def _update_single_chunk(
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
personas: _Personas | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
@@ -232,11 +227,6 @@ def _update_single_chunk(
if update_request.project_ids is not None
else None
)
personas_update: _Personas | None = (
_Personas(assign=list(update_request.persona_ids))
if update_request.persona_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
@@ -244,7 +234,6 @@ def _update_single_chunk(
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
personas=personas_update,
)
vespa_put_request = _VespaPutRequest(
@@ -565,9 +554,10 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Avoid over-fetching a very large candidate set for global-phase reranking.
# Keep enough headroom for quality while capping cost on larger indices.
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)

View File

@@ -58,7 +58,6 @@ DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
PERSONAS = "personas"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,16 +5,12 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
class AzureImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -56,25 +52,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
deployment_name=credentials.deployment_name,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Azure GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -82,44 +59,13 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation
deployment = self._deployment_name or model
model_name = f"azure/{deployment}"
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model_name,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(
prompt=prompt,
model=model_name,

View File

@@ -5,16 +5,12 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
class OpenAIImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -42,25 +38,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
api_base=credentials.api_base,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -68,38 +45,8 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model,
api_key=self._api_key,
api_base=self._api_base,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -146,7 +146,6 @@ class DocumentIndexingBatchAdapter:
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map

View File

@@ -182,7 +182,7 @@ class UserFileIndexingAdapter:
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=[],
# we are going to index userfiles only once, so we just set the boost to the default
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],

View File

@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -229,8 +228,6 @@ def index_doc_batch_prepare(
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = sanitize_documents_for_postgres(documents)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]

View File

@@ -112,7 +112,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess"
document_sets: set[str]
user_project: list[int]
personas: list[int]
boost: int
aggregated_chunk_boost_factor: float
# Full ancestor path from root hierarchy node to document's parent.
@@ -127,7 +126,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
user_project: list[int],
personas: list[int],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
@@ -139,7 +137,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,

View File

@@ -1,150 +0,0 @@
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
def _sanitize_string(value: str) -> str:
return value.replace("\x00", "")
def _sanitize_json_like(value: Any) -> Any:
if isinstance(value, str):
return _sanitize_string(value)
if isinstance(value, list):
return [_sanitize_json_like(item) for item in value]
if isinstance(value, tuple):
return tuple(_sanitize_json_like(item) for item in value)
if isinstance(value, dict):
sanitized: dict[Any, Any] = {}
for key, nested_value in value.items():
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
return sanitized
return value
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
return expert.model_copy(
update={
"display_name": (
_sanitize_string(expert.display_name)
if expert.display_name is not None
else None
),
"first_name": (
_sanitize_string(expert.first_name)
if expert.first_name is not None
else None
),
"middle_initial": (
_sanitize_string(expert.middle_initial)
if expert.middle_initial is not None
else None
),
"last_name": (
_sanitize_string(expert.last_name)
if expert.last_name is not None
else None
),
"email": (
_sanitize_string(expert.email) if expert.email is not None else None
),
}
)
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
return ExternalAccess(
external_user_emails={
_sanitize_string(email) for email in external_access.external_user_emails
},
external_user_group_ids={
_sanitize_string(group_id)
for group_id in external_access.external_user_group_ids
},
is_public=external_access.is_public,
)
def sanitize_document_for_postgres(document: Document) -> Document:
cleaned_doc = document.model_copy(deep=True)
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
if cleaned_doc.title is not None:
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
cleaned_doc.parent_hierarchy_raw_node_id
)
cleaned_doc.metadata = {
_sanitize_string(key): (
[_sanitize_string(item) for item in value]
if isinstance(value, list)
else _sanitize_string(value)
)
for key, value in cleaned_doc.metadata.items()
}
if cleaned_doc.doc_metadata is not None:
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
if cleaned_doc.primary_owners is not None:
cleaned_doc.primary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
]
if cleaned_doc.secondary_owners is not None:
cleaned_doc.secondary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
]
if cleaned_doc.external_access is not None:
cleaned_doc.external_access = _sanitize_external_access(
cleaned_doc.external_access
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = _sanitize_string(section.link)
if section.text is not None:
section.text = _sanitize_string(section.text)
if section.image_file_id is not None:
section.image_file_id = _sanitize_string(section.image_file_id)
return cleaned_doc
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
return [sanitize_document_for_postgres(document) for document in documents]
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
cleaned_node = node.model_copy(deep=True)
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
if cleaned_node.raw_parent_id is not None:
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
if cleaned_node.link is not None:
cleaned_node.link = _sanitize_string(cleaned_node.link)
if cleaned_node.external_access is not None:
cleaned_node.external_access = _sanitize_external_access(
cleaned_node.external_access
)
return cleaned_node
def sanitize_hierarchy_nodes_for_postgres(
nodes: list[HierarchyNode],
) -> list[HierarchyNode]:
return [sanitize_hierarchy_node_for_postgres(node) for node in nodes]

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,7 +1,5 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def _load_bundled_recommendations() -> LLMRecommendations:
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def is_obsolete_model(model_name: str, provider: str) -> bool:

View File

@@ -97,9 +97,6 @@ from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.kg.api import admin_router as kg_admin_router
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.code_interpreter.api import (
admin_router as code_interpreter_admin_router,
)
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
@@ -424,9 +421,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, kg_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, code_interpreter_admin_router
)
include_router_with_global_prefix_prepended(
application, image_generation_admin_router
)

View File

@@ -1,68 +1,14 @@
import re
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_citation_link_destinations(message: str) -> str:
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
if "[[" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _CITATION_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
normalized_message = _normalize_citation_link_destinations(message)
result = md(normalized_message)
result = md(message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result

View File

@@ -1,6 +1,6 @@
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included, the sections below are for the available tools
TOOL_SECTION_HEADER = "\n# Tools\n\n"
TOOL_SECTION_HEADER = "\n\n# Tools\n"
# This section is included if there are search type tools, currently internal_search and web_search
@@ -16,10 +16,11 @@ When searching for information, if the initial results cannot fully answer the u
Do not repeat the same or very similar queries if it already has been run in the chat history.
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
""".lstrip()
"""
INTERNAL_SEARCH_GUIDANCE = """
## internal_search
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
@@ -27,31 +28,34 @@ Use the `internal_search` tool to search connected applications for information.
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
- Ambiguity: questions about something that is not widely known or understood.
Never provide more than 3 queries at once to `internal_search`.
""".lstrip()
"""
WEB_SEARCH_GUIDANCE = """
## web_search
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
- Accuracy: if the cost of outdated/inaccurate information is high.
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
""".lstrip()
"""
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".lstrip()
""".rstrip()
OPEN_URLS_GUIDANCE = """
## open_url
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
Do not open URLs that are image files like .png, .jpg, etc.
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
""".lstrip()
"""
PYTHON_TOOL_GUIDANCE = """
## python
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
@@ -60,21 +64,21 @@ Use this to give the user a way to download the file OR to display generated ima
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
""".lstrip()
"""
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
""".lstrip()
"""
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
""".lstrip()
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.

View File

@@ -1,36 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n# User Information\n\n"
USER_INFORMATION_HEADER = "\n\n# User Information\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
""".lstrip()
"""
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
""".lstrip()
"""
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
""".lstrip()
"""
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
""".lstrip()
"""
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
""".lstrip()
"""
# ruff: noqa: E501, W605 end

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -988,7 +987,6 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1002,23 +1000,11 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
# Get most recent finished index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
]
if user and user.role == UserRole.ADMIN:
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
),
)
@@ -1939,7 +1911,6 @@ Tenant ID: {tenant_id}
class BasicCCPairInfo(BaseModel):
has_successful_run: bool
source: DocumentSource
status: ConnectorCredentialPairStatus
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
@@ -1953,17 +1924,13 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
status=cc_pair.status,
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
cc_pair_model.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -762,43 +762,6 @@ def download_webapp(
)
@router.get("/{session_id}/download-directory/{path:path}")
def download_directory(
session_id: UUID,
path: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Response:
"""
Download a directory as a zip file.
Returns the specified directory as a zip archive.
"""
user_id: UUID = user.id
session_manager = SessionManager(db_session)
try:
result = session_manager.download_directory(session_id, user_id, path)
except ValueError as e:
error_message = str(e)
if "path traversal" in error_message.lower():
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=400, detail=error_message)
if result is None:
raise HTTPException(status_code=404, detail="Directory not found")
zip_bytes, filename = result
return Response(
content=zip_bytes,
media_type="application/zip",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
},
)
@router.post("/{session_id}/upload", response_model=UploadResponse)
def upload_file_endpoint(
session_id: UUID,

View File

@@ -107,23 +107,27 @@ def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int,
)
for cc_pair in cc_pairs:
if (
cc_pair.connector.source == DocumentSource.CRAFT_FILE
and cc_pair.creator_id == user.id
):
if cc_pair.connector.source == DocumentSource.CRAFT_FILE:
return cc_pair.connector.id, cc_pair.credential.id
# No cc_pair for this user — find or create the shared CRAFT_FILE connector
# Check for orphaned connector (created but cc_pair creation failed previously)
existing_connectors = fetch_connectors(
db_session, sources=[DocumentSource.CRAFT_FILE]
)
connector_id: int | None = None
orphaned_connector = None
for conn in existing_connectors:
if conn.name == USER_LIBRARY_CONNECTOR_NAME:
connector_id = conn.id
if conn.name != USER_LIBRARY_CONNECTOR_NAME:
continue
if not conn.credentials:
orphaned_connector = conn
break
if connector_id is None:
if orphaned_connector:
connector_id = orphaned_connector.id
logger.info(
f"Found orphaned User Library connector {connector_id}, completing setup"
)
else:
connector_data = ConnectorBase(
name=USER_LIBRARY_CONNECTOR_NAME,
source=DocumentSource.CRAFT_FILE,

View File

@@ -68,7 +68,6 @@ from onyx.server.features.build.db.sandbox import create_sandbox__no_commit
from onyx.server.features.build.db.sandbox import get_running_sandbox_count_by_tenant
from onyx.server.features.build.db.sandbox import get_sandbox_by_session_id
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import get_snapshots_for_session
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
from onyx.server.features.build.sandbox import get_sandbox_manager
@@ -647,30 +646,16 @@ class SessionManager:
if sandbox and sandbox.status.is_active():
# Quick health check to verify sandbox is actually responsive
# AND verify the session workspace still exists on disk
# (it may have been wiped if the sandbox was re-provisioned)
is_healthy = self._sandbox_manager.health_check(sandbox.id, timeout=5.0)
workspace_exists = (
is_healthy
and self._sandbox_manager.session_workspace_exists(
sandbox.id, existing.id
)
)
if is_healthy and workspace_exists:
if self._sandbox_manager.health_check(sandbox.id, timeout=5.0):
logger.info(
f"Returning existing empty session {existing.id} for user {user_id}"
)
return existing
elif not is_healthy:
else:
logger.warning(
f"Empty session {existing.id} has unhealthy sandbox {sandbox.id}. "
f"Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} workspace missing in sandbox "
f"{sandbox.id}. Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} has no active sandbox "
@@ -1050,23 +1035,6 @@ class SessionManager:
# workspace cleanup fails (e.g., if pod is already terminated)
logger.warning(f"Failed to cleanup session workspace {session_id}: {e}")
# Delete snapshot files from S3 before removing DB records
snapshots = get_snapshots_for_session(self._db_session, session_id)
if snapshots:
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.sandbox.manager.snapshot_manager import (
SnapshotManager,
)
snapshot_manager = SnapshotManager(get_default_file_store())
for snapshot in snapshots:
try:
snapshot_manager.delete_snapshot(snapshot.storage_path)
except Exception as e:
logger.warning(
f"Failed to delete snapshot file {snapshot.storage_path}: {e}"
)
# Delete session (uses flush, caller commits)
return delete_build_session__no_commit(session_id, user_id, self._db_session)
@@ -1935,94 +1903,6 @@ class SessionManager:
return zip_buffer.getvalue(), filename
def download_directory(
self,
session_id: UUID,
user_id: UUID,
path: str,
) -> tuple[bytes, str] | None:
"""
Create a zip file of an arbitrary directory in the session workspace.
Args:
session_id: The session UUID
user_id: The user ID to verify ownership
path: Relative path to the directory (within session workspace)
Returns:
Tuple of (zip_bytes, filename) or None if session not found
Raises:
ValueError: If path traversal attempted or path is not a directory
"""
# Verify session ownership
session = get_build_session(session_id, user_id, self._db_session)
if session is None:
return None
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
if sandbox is None:
return None
# Check if directory exists
try:
self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError:
return None
# Recursively collect all files
def collect_files(dir_path: str) -> list[tuple[str, str]]:
"""Collect all files recursively, returning (full_path, arcname) tuples."""
files: list[tuple[str, str]] = []
try:
entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=dir_path,
)
for entry in entries:
if entry.is_directory:
files.extend(collect_files(entry.path))
else:
# arcname is relative to the target directory
prefix_len = len(path) + 1 # +1 for trailing slash
arcname = entry.path[prefix_len:]
files.append((entry.path, arcname))
except ValueError:
pass
return files
file_list = collect_files(path)
# Create zip file in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
for full_path, arcname in file_list:
try:
content = self._sandbox_manager.read_file(
sandbox_id=sandbox.id,
session_id=session_id,
path=full_path,
)
zip_file.writestr(arcname, content)
except ValueError:
pass
zip_buffer.seek(0)
# Use the directory name for the zip filename
dir_name = Path(path).name
safe_name = "".join(
c if c.isalnum() or c in ("-", "_", ".") else "_" for c in dir_name
)
filename = f"{safe_name}.zip"
return zip_buffer.getvalue(), filename
# =========================================================================
# File System Operations
# =========================================================================
@@ -2057,18 +1937,11 @@ class SessionManager:
return None
# Use sandbox manager to list directory (works for both local and K8s)
# If the directory doesn't exist (e.g., session workspace not yet loaded),
# return an empty listing rather than erroring out.
try:
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError as e:
if "path traversal" in str(e).lower():
raise
return DirectoryListing(path=path, entries=[])
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
# Filter hidden files and directories
entries: list[FileSystemEntry] = [

View File

@@ -111,8 +111,7 @@ class DocumentSet(BaseModel):
id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=[cc_pair.credential_id],
cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential

View File

@@ -12,18 +12,11 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import UserFileStatus
from onyx.db.models import ChatSession
@@ -34,7 +27,6 @@ from onyx.db.models import UserProject
from onyx.db.persona import get_personas_by_ids
from onyx.db.projects import get_project_token_count
from onyx.db.projects import upload_files_to_user_files_with_indexing
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import ChatSessionRequest
from onyx.server.features.projects.models import TokenCountResponse
@@ -55,33 +47,6 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
queue_depth = get_user_file_project_sync_queue_depth(client_app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
logger.warning(
f"Skipping immediate project sync for user_file_id={user_file_id} due to "
f"queue depth {queue_depth}>{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}. "
"It will be picked up by beat later."
)
return
redis_client = get_redis_client(tenant_id=tenant_id)
enqueued = enqueue_user_file_project_sync_task(
celery_app=client_app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGHEST,
)
if not enqueued:
logger.info(
f"Skipped duplicate project sync enqueue for user_file_id={user_file_id}"
)
return
logger.info(f"Triggered project sync for user_file_id={user_file_id}")
@router.get("", tags=PUBLIC_API_TAGS)
def get_projects(
user: User = Depends(current_user),
@@ -224,7 +189,15 @@ def unlink_user_file_from_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
return Response(status_code=204)
@@ -268,7 +241,15 @@ def link_user_file_to_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id)
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
return UserFileSnapshot.from_model(user_file)

View File

@@ -1,47 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
admin_router = APIRouter(prefix="/admin/code-interpreter")
@admin_router.get("/health")
def get_code_interpreter_health(
_: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
try:
client = CodeInterpreterClient()
return CodeInterpreterServerHealth(healthy=client.health())
except ValueError:
return CodeInterpreterServerHealth(healthy=False)
@admin_router.get("")
def get_code_interpreter(
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
ci_server = fetch_code_interpreter_server(db_session)
return CodeInterpreterServer(enabled=ci_server.server_enabled)
@admin_router.put("")
def update_code_interpreter(
update: CodeInterpreterServer,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_code_interpreter_server_enabled(
db_session=db_session,
enabled=update.enabled,
)

View File

@@ -1,9 +0,0 @@
from pydantic import BaseModel
class CodeInterpreterServer(BaseModel):
enabled: bool
class CodeInterpreterServerHealth(BaseModel):
healthy: bool

View File

@@ -35,18 +35,6 @@ if TYPE_CHECKING:
pass
class EmailInviteStatus(str, Enum):
SENT = "SENT"
NOT_CONFIGURED = "NOT_CONFIGURED"
SEND_FAILED = "SEND_FAILED"
DISABLED = "DISABLED"
class BulkInviteResponse(BaseModel):
invited_count: int
email_invite_status: EmailInviteStatus
class VersionResponse(BaseModel):
backend_version: str

View File

@@ -36,7 +36,6 @@ from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
@@ -79,10 +78,8 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import BulkInviteResponse
from onyx.server.manage.models import ChatBackgroundRequest
from onyx.server.manage.models import DefaultAppModeRequest
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import PersonalizationUpdateRequest
from onyx.server.manage.models import TenantInfo
@@ -371,7 +368,7 @@ def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> BulkInviteResponse:
) -> int:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
tenant_id = get_current_tenant_id()
@@ -430,41 +427,34 @@ def bulk_invite_users(
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations only to new users (not already invited or existing)
if not ENABLE_EMAIL_INVITES:
email_invite_status = EmailInviteStatus.DISABLED
elif not EMAIL_CONFIGURED:
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
else:
if ENABLE_EMAIL_INVITES:
try:
for email in emails_needing_seats:
send_user_email_invite(email, current_user, AUTH_TYPE)
email_invite_status = EmailInviteStatus.SENT
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
email_invite_status = EmailInviteStatus.SEND_FAILED
if MULTI_TENANT and not DEV_MODE:
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
return BulkInviteResponse(
invited_count=number_of_invited_users,
email_invite_status=email_invite_status,
)
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)

View File

@@ -587,7 +587,6 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
)
result = gather_stream_full(packets, state_container)
@@ -610,7 +609,6 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
):
yield get_json_line(obj.model_dump())

View File

@@ -125,11 +125,6 @@ class SendMessageRequest(BaseModel):
# - No CitationInfo packets are emitted during streaming
include_citations: bool = True
# Additional context injected into the LLM call but NOT stored in the DB
# (not shown in chat history). Used e.g. by the Chrome extension to pass
# the current tab URL when "Read this tab" is enabled.
additional_context: str | None = None
@model_validator(mode="after")
def check_chat_session_id_or_info(self) -> "SendMessageRequest":
# If neither is provided, default to creating a new chat session using the

View File

@@ -36,8 +36,6 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
@@ -52,7 +50,6 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.utils.logger import setup_logger
@@ -380,37 +377,6 @@ def create_memory_packets(
return packets
def create_python_tool_packets(
code: str,
stdout: str,
stderr: str,
file_ids: list[str],
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
tool call data so the frontend can display both the code and its output
on page reload."""
packets: list[Packet] = []
placement = Placement(turn_index=turn_index, tab_index=tab_index)
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
packets.append(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
),
)
)
packets.append(Packet(placement=placement, obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
search_docs: list[SavedSearchDoc],
@@ -620,41 +586,6 @@ def translate_assistant_message_to_packets(
)
)
elif tool.in_code_tool_id == PythonTool.__name__:
code = cast(
str,
tool_call.tool_call_arguments.get("code", ""),
)
stdout = ""
stderr = ""
file_ids: list[str] = []
if tool_call.tool_call_response:
try:
response_data = json.loads(tool_call.tool_call_response)
stdout = response_data.get("stdout", "")
stderr = response_data.get("stderr", "")
generated_files = response_data.get(
"generated_files", []
)
file_ids = [
f.get("file_link", "").split("/")[-1]
for f in generated_files
if f.get("file_link")
]
except (json.JSONDecodeError, KeyError):
# Fall back to raw response as stdout
stdout = tool_call.tool_call_response
turn_tool_packets.extend(
create_python_tool_packets(
code=code,
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
turn_index=turn_num,
tab_index=tool_call.tab_index,
)
)
else:
# Custom tool or unknown tool
turn_tool_packets.extend(

View File

@@ -24,7 +24,6 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SAML_CONF_DIR
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
@@ -124,12 +123,9 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
if request.client is None:
raise ValueError("Invalid request for SAML")
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
# to poison SAML redirect URLs (host header poisoning).
parsed_domain = urlparse(WEB_DOMAIN)
http_host = parsed_domain.hostname or request.client.host
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
# Use X-Forwarded headers if available
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
rv: dict[str, Any] = {
"http_host": http_host,

View File

@@ -57,7 +57,6 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status

View File

@@ -245,11 +245,7 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
if (
GEN_AI_API_KEY
and fetch_default_llm_model(db_session) is None
and not INTEGRATION_TESTS_MODE
):
if GEN_AI_API_KEY and fetch_default_llm_model(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")

Some files were not shown because too many files have changed in this diff Show More