Compare commits

...

69 Commits

Author SHA1 Message Date
rohoswagger
43a59e4d74 chore: trim redundant source-count logging in generate_agents_md.py
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 14:33:17 -08:00
rohoswagger
dc8fc7eefc refactor: simplify to single-step AGENTS.md generation via stdin pipe
Instead of writing AGENTS.md with the placeholder and then calling the
script to replace it, pipe the template directly into the script via
stdin so the final file is written in one step.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 14:07:25 -08:00
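A minimal sketch of the single-step flow this commit describes, assuming the KNOWLEDGE_SOURCES_SECTION placeholder mentioned in the commit below; the helper and path names here are illustrative, not the script's actual implementation:

import sys
from pathlib import Path

PLACEHOLDER = "KNOWLEDGE_SOURCES_SECTION"  # placeholder name taken from the commit below

def build_knowledge_sources_section(files_dir: Path) -> str:
    # Hypothetical helper: list whatever files are visible under the knowledge dir.
    if not files_dir.is_dir():
        return "(no knowledge sources found)"
    return "\n".join(f"- {p.name}" for p in sorted(files_dir.iterdir()))

def main() -> None:
    # The setup script pipes the template in via stdin, so the final AGENTS.md
    # is written in one step with no intermediate placeholder file on disk.
    template = sys.stdin.read()
    section = build_knowledge_sources_section(Path("files"))
    Path("AGENTS.md").write_text(template.replace(PLACEHOLDER, section))

if __name__ == "__main__":
    main()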
rohoswagger
78f7914093 fix: invoke generate_agents_md.py in K8s to populate knowledge sources
The KNOWLEDGE_SOURCES_SECTION placeholder in AGENTS.md was never being
replaced in Kubernetes environments. The script existed in the container
image but was never called during session setup.

Now the K8s setup script calls generate_agents_md.py after writing
AGENTS.md, scanning the files symlink to populate the knowledge sources
section. Also updated the script to accept CLI args instead of env vars
to fit the invocation context.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 13:34:29 -08:00
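A hedged sketch of the CLI-args change this commit describes; the flag names (--files-dir, --agents-md-path) are assumptions for illustration rather than the script's documented interface:

import argparse
from pathlib import Path

def parse_args() -> argparse.Namespace:
    # Accept inputs as CLI args so the K8s setup script can pass them at
    # invocation time instead of exporting environment variables first.
    parser = argparse.ArgumentParser(
        description="Populate the knowledge sources section of AGENTS.md"
    )
    parser.add_argument(
        "--files-dir",
        type=Path,
        default=Path("files"),
        help="Directory (or symlink) to scan for knowledge source files",
    )
    parser.add_argument(
        "--agents-md-path",
        type=Path,
        default=Path("AGENTS.md"),
        help="AGENTS.md file whose placeholder section should be filled in",
    )
    return parser.parse_args()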
Evan Lohn
f4d777b80d refactor: persona id in vector db (#8680) 2026-02-25 20:42:38 +00:00
acaprau
da4d57b5e3 chore(devtools): Make AGENTS.md reference contributing_guides/best_practices.md (#8760)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-25 20:27:12 +00:00
Evan Lohn
dcdcd067bd fix: drive 403 rate limits (#8762) 2026-02-25 20:12:36 +00:00
Evan Lohn
8b15a29723 feat: slab connector validation (#8758) 2026-02-25 20:00:42 +00:00
Danelegend
763853674f feat(ci): Add preview modal for data types (#8752) 2026-02-25 19:52:19 +00:00
Jamison Lahman
429b6f3465 fix(fe): modal aligning with detached element after navigation (#8676) 2026-02-25 19:33:07 +00:00
Danelegend
37d5be1b40 feat: python tool not added when no code interpreter server (#8749) 2026-02-25 19:17:42 +00:00
Jamison Lahman
8ab99dbb06 chore(fe): add hover style to AgentCard (#8689) 2026-02-25 19:08:00 +00:00
Jamison Lahman
52799e9c7a fix(fe): middle align human chat message text (#8756) 2026-02-25 19:00:01 +00:00
Jamison Lahman
aef009cc97 chore(fe): foldable buttons display text via tooltip when disabled (#8735) 2026-02-25 18:39:53 +00:00
Evan Lohn
18d1ea1770 fix: sharepoint driveItem perm sync (#8698) 2026-02-25 18:29:26 +00:00
Bo-Onyx
f336ad00f4 fix(user invitation): invitation failed but no warning shown. (#8731)
Co-authored-by: Bo Yang <boyang@Bos-MacBook-Pro.local>
2026-02-25 17:23:39 +00:00
SubashMohan
0558e687d9 fix: persist onboarding dismissal in localStorage with user-specific keys (#8674) 2026-02-25 06:22:17 +00:00
roshan
784a99e24a updated demo data (#8748) 2026-02-24 19:59:46 -08:00
Justin Tahara
da1f5a11f4 chore(cherry-pick): Alerting on Failed Cherry-Picks (#8744) 2026-02-25 02:09:19 +00:00
Justin Tahara
5633805890 chore(devtools): Upgrade ods from 0.6.0 -> 0.6.1 (#8743) 2026-02-25 02:01:20 +00:00
Danelegend
0817b45ae1 feat: Get code interpreter config route (#8739) 2026-02-25 01:49:30 +00:00
Justin Tahara
af0e4bdebc fix(slack): Cleaning up URL Links (#8569) 2026-02-25 01:42:12 +00:00
Justin Tahara
4cd2320732 chore(cherry-pick): Add Github Label for PRs (#8736) 2026-02-25 00:46:12 +00:00
Danelegend
90a361f0e1 feat: code interpreter routes (#8670) 2026-02-24 16:27:10 -08:00
Justin Tahara
194efde97b chore(llm): Scaffolding for Nightly LLM Tests (#8704) 2026-02-25 00:06:24 +00:00
Danelegend
d922a42262 feat: code interpreter docker default deploy (#8672) 2026-02-24 23:51:19 +00:00
Danelegend
f00c3a486e feat: default deploy code interpreter - helm & bump version 0.3.0 (#8685) 2026-02-24 23:40:46 +00:00
Danelegend
192080c9e4 feat: default deploy code interpreter - restart_script (#8686) 2026-02-24 23:40:36 +00:00
Justin Tahara
c5787dc073 chore(image): Update test to be for Dall E 3 instead of 2 (#8732) 2026-02-24 22:53:31 +00:00
Justin Tahara
d424d6462c fix(sanitization): Centralizing DB Filters (#8730) 2026-02-24 22:28:25 +00:00
Jamison Lahman
ecea86deb6 chore(fe): only left input items flex (#8734) 2026-02-24 22:25:04 +00:00
Jamison Lahman
a5c1f50a8a chore(fe): update disabled "select" button color (#8733) 2026-02-24 22:03:52 +00:00
roshan
4a04cfd486 feat(craft): make output/ files downloadable from Artifacts tab (#8721)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-24 21:49:59 +00:00
Nikolas Garza
f22e9628db feat(scim): add additional entra id fields to ScimUserMapping (#8728) 2026-02-24 20:23:21 +00:00
Jamison Lahman
255ba10af6 chore(chat): consolidate chat message whitespacing style (#8696) 2026-02-24 20:02:28 +00:00
Justin Tahara
563202a080 feat(image): support Azure historical image context edits (#8726) 2026-02-24 19:21:30 +00:00
Evan Lohn
1062dc0743 fix: graph client env (#8727) 2026-02-24 18:46:49 +00:00
Justin Tahara
0826348568 feat(image): support OpenAI historical image context edits (#8725) 2026-02-24 18:45:56 +00:00
Justin Tahara
375079136d chore(cherry-pick): Assign merged-by user on beta cherry-pick PR (#8723) 2026-02-24 18:27:48 +00:00
Jamison Lahman
82aad5e253 fix(welcome): add back agent description (#8716) 2026-02-24 17:27:23 +00:00
Jamison Lahman
beb1c49c69 fix(fe): inline code-blocks respect header font-size (#8691) 2026-02-24 17:03:21 +00:00
Jamison Lahman
c4556515be fix(fe): rm non-admin-confirmation max-width (#8693) 2026-02-24 17:03:05 +00:00
SubashMohan
a4387f230b fix(popover): prevent viewport overflow with dynamic max-height and collision padding (#8675) 2026-02-24 10:27:36 +00:00
Evan Lohn
d91e452658 chore: version bumps for client libs (#8720) 2026-02-24 08:13:37 +00:00
Danelegend
dd274f8667 feat: code interpreter supports streaming (#8663) 2026-02-24 06:07:36 +00:00
roshan
2c82f0da16 fix(craft): delete S3 snapshot files when deleting a craft (#8718)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 05:58:29 +00:00
Raunak Bhagat
26101636f2 refactor: add new ContentAction component (#8695) 2026-02-24 05:13:18 +00:00
roshan
5e2c0c6cf4 fix(nrf): hide search toggle when search mode is unavailable (#8717)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:43:19 -08:00
roshan
33b64db498 fix(extensions): fix base url for chrome extension to (#8714) 2026-02-23 20:18:05 -08:00
roshan
b925cc1a56 feat(chrome-extension): add tab reading to side panel (#8571)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-24 01:17:57 +00:00
Danelegend
bac4b7c945 fix: preview markdown formatting (#8667) 2026-02-24 01:13:52 +00:00
Evan Lohn
6f6ef1e657 chore: coerce doc metadata (#8703) 2026-02-24 01:12:11 +00:00
Danelegend
885c69f460 feat: Improve csv preview modal (#8702) 2026-02-24 01:00:20 +00:00
Danelegend
4b837303ff feat(code-interpreter): Seed code interpreter server row (#8701) 2026-02-24 00:59:49 +00:00
Justin Tahara
d856a9befb fix(projects): Guardrails for Project User Files (#8644) 2026-02-24 00:21:57 +00:00
Justin Tahara
adade353c5 fix(api): Improving the API handling of threads (#8573) 2026-02-24 00:04:21 +00:00
Nikolas Garza
3cb6ec2f85 fix: patch prometheus metrics in daily test fixture (#8699) 2026-02-24 00:02:56 +00:00
Wenxi
691eebf00a fix: remove user info requirement for craft onboarding modal (#8697) 2026-02-23 23:52:17 +00:00
Danelegend
905b6633e6 chore: preview modal (#8665) 2026-02-23 23:40:55 +00:00
Justin Tahara
fd088196ff fix(search): Improve Speed (#8430) 2026-02-23 22:45:18 +00:00
Jamison Lahman
cafbf5b8be chore(playwright): warn user if setup takes longer than usual (#8690) 2026-02-23 22:23:58 +00:00
roshan
1235181559 fix(ui): Clean up NRF settings button styling (#8678)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 21:25:43 +00:00
Justin Tahara
caa2e45632 fix(db): Multitenant Schema migration update (#8679) 2026-02-23 21:25:26 +00:00
Justin Tahara
9c62e03120 chore(ods): Automated Cherry-pick backport (#8642) 2026-02-23 21:15:09 +00:00
Nikolas Garza
0937305064 feat(scim): Okta compatibility + provider abstraction (#8568) 2026-02-23 21:09:18 +00:00
Wenxi
e4c06570e3 fix: domain rules for signup on cloud (#8671) 2026-02-23 20:27:37 +00:00
roshan
78fc7c86d7 fix: Handle unauthenticated state gracefully on NRF page (#8491)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 19:26:38 +00:00
Raunak Bhagat
84d3aea847 refactor: migrate Web Search page to SettingsLayouts + Content (#8662) 2026-02-23 13:38:37 +00:00
Danelegend
00a404d3cd feat: Add code interpreter server db model (#8669) 2026-02-23 05:09:59 +00:00
Wenxi
787cf90d96 chore: set trial api usage to 0 and show ui (#8664) 2026-02-23 01:41:23 +00:00
234 changed files with 8298 additions and 2671 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check

View File

@@ -0,0 +1,161 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -116,7 +116,6 @@ jobs:
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF

View File

@@ -20,6 +20,7 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -423,6 +424,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -443,6 +445,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -701,6 +704,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
To run them:
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -21,15 +21,14 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, NamedTuple
from typing import NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -0,0 +1,29 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There must always be exactly one code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -0,0 +1,48 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -0,0 +1,31 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -127,9 +127,14 @@ class ScimDAL(DAL):
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
)
self._session.add(mapping)
self._session.flush()
return mapping
@@ -248,11 +253,11 @@ class ScimDAL(DAL):
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
A tuple of (list of (user, mapping) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
@@ -292,33 +297,104 @@ class ScimDAL(DAL):
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
.unique()
.all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user."""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
# ------------------------------------------------------------------
# Group mapping operations
@@ -483,9 +559,13 @@ class ScimDAL(DAL):
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users_by_id = {u.id: u for u in users}
return [
@@ -504,9 +584,13 @@ class ScimDAL(DAL):
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
num_hits: int = 30
include_content: bool = False
stream: bool = False

View File

@@ -26,12 +26,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
@@ -41,6 +39,8 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -53,7 +53,6 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
@@ -63,6 +62,18 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@@ -100,28 +111,6 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -155,9 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
)
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or name.formatted
# ---------------------------------------------------------------------------
@@ -171,6 +161,7 @@ def list_users(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List users with optional SCIM filter and pagination."""
@@ -183,12 +174,19 @@ def list_users(
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
)
for user, mapping in users_with_mappings
]
return ScimListResponse(
@@ -203,6 +201,7 @@ def list_users(
def get_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Get a single user by ID."""
@@ -215,20 +214,26 @@ def get_user(
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
return provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
@@ -264,11 +269,14 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
scim_username = user_resource.userName.strip()
dal.create_user_mapping(
external_id=external_id, user_id=user.id, scim_username=scim_username
)
dal.commit()
return _user_to_scim(user, external_id)
return provider.build_user_resource(user, external_id, scim_username=scim_username)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -276,6 +284,7 @@ def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
@@ -293,19 +302,27 @@ def replace_user(
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
email=user_resource.userName.strip(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
personal_name=personal_name,
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
scim_username = user_resource.userName.strip()
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
dal.commit()
return _user_to_scim(user, new_external_id)
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
)
@scim_router.patch("/Users/{user_id}", response_model=None)
@@ -313,6 +330,7 @@ def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
@@ -330,11 +348,19 @@ def patch_user(
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current = _user_to_scim(user, external_id)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
)
try:
patched = apply_user_patch(patch_request.Operations, current)
patched = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -345,22 +371,40 @@ def patch_user(
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
personal_name=personal_name,
)
dal.sync_user_external_id(user.id, patched.externalId)
dal.sync_user_external_id(
user.id, patched.externalId, scim_username=new_scim_username
)
dal.commit()
return _user_to_scim(user, patched.externalId)
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -398,24 +442,6 @@ def delete_user(
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
@@ -474,6 +500,7 @@ def list_groups(
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
"""List groups with optional SCIM filter and pagination."""
@@ -491,7 +518,7 @@ def list_groups(
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
@@ -507,6 +534,7 @@ def list_groups(
def get_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Get a single group by ID."""
@@ -521,13 +549,16 @@ def get_group(
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
return provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Create a new group from a SCIM provisioning request."""
@@ -565,7 +596,7 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
return provider.build_group_resource(db_group, members, external_id)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -573,6 +604,7 @@ def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
@@ -595,7 +627,7 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
return provider.build_group_resource(group, members, group_resource.externalId)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -603,6 +635,7 @@ def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
@@ -621,11 +654,11 @@ def patch_group(
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
current = provider.build_group_resource(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -652,7 +685,7 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
return provider.build_group_resource(group, members, patched.externalId)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)

View File

@@ -63,6 +63,13 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -76,8 +83,10 @@ class ScimUserResource(BaseModel):
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
@@ -121,12 +130,40 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
value: ScimPatchValue = None
class ScimPatchRequest(BaseModel):

View File

@@ -16,9 +16,12 @@ from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
@@ -41,9 +44,15 @@ _MEMBER_FILTER_RE = re.compile(
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
@@ -55,9 +64,9 @@ def apply_user_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
_apply_user_replace(op, data, name_data, ignored_paths)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
_apply_user_replace(op, data, name_data, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
@@ -71,30 +80,34 @@ def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
# No path — value is a resource dict of top-level attributes to set
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
_set_user_field(path, op.value, data, name_data, ignored_paths)
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
value: ScimPatchValue,
data: dict,
name_data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
if path in ignored_paths:
return
elif path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
@@ -107,7 +120,7 @@ def _set_user_field(
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
data["displayName"] = value
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
@@ -116,9 +129,15 @@ def _set_user_field(
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -133,7 +152,9 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -154,38 +175,48 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
_set_group_field(key.lower(), val, data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
return
_set_group_field(path, op.value, data)
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
def _replace_members(
value: str | list | dict | bool | None,
value: list[dict],
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -197,11 +228,14 @@ def _replace_members(
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
value: ScimPatchValue,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
if path in ignored_paths:
return
elif path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
@@ -223,8 +257,10 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in op.value:
for member_data in member_dicts:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -0,0 +1,123 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
"""
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=username,
name=self._build_scim_name(user),
displayName=user.personal_name,
emails=[ScimEmail(value=username, type="work", primary=True)],
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -0,0 +1,25 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return frozenset({"id", "schemas", "meta"})

View File

@@ -277,13 +277,32 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Reject googlemail.com and ask users to sign up with @gmail.com instead (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
)
# Check domain whitelist if configured

View File

@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -58,6 +59,7 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
@@ -276,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -325,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -5,8 +5,10 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
@@ -24,12 +26,14 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -75,10 +79,58 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
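
The NX+EX guard above is what makes the enqueue idempotent per file within the TTL window. Below is a small in-memory sketch of the same dedupe semantics, using a stand-in object instead of Redis so it runs without a broker; names are illustrative only.

import time

class GuardStore:
    """In-memory stand-in for the Redis SET NX EX queued guard."""
    def __init__(self) -> None:
        self._expiry: dict[str, float] = {}

    def set_nx_ex(self, key: str, ttl_seconds: int) -> bool:
        now = time.monotonic()
        if self._expiry.get(key, 0.0) > now:
            return False  # guard already held -> caller skips the enqueue
        self._expiry[key] = now + ttl_seconds
        return True

guards = GuardStore()
assert guards.set_nx_ex("user_file_project_sync_queued:file-1", 60) is True   # first enqueue wins
assert guards.set_nx_ex("user_file_project_sync_queued:file-1", 60) is False  # duplicate is skipped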
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -632,8 +684,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -645,7 +697,16 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -661,19 +722,23 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGH,
)
):
skipped_guard += 1
continue
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
)
return None
@@ -692,6 +757,8 @@ def process_single_user_file_project_sync(
)
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,

View File

@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
cleaned_batch.append(sanitize_document_for_postgres(doc))
return cleaned_batch
@@ -602,10 +575,13 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch,
nodes=hierarchy_node_batch_cleaned,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -624,7 +600,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -30,6 +30,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
@@ -656,7 +657,12 @@ def run_llm_loop(
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
# Fetch this in a short-lived session so the long-running stream loop does
# not pin a connection just to keep read state alive.
with get_session_with_current_tenant() as prompt_db_session:
default_base_system_prompt: str = get_default_base_system_prompt(
prompt_db_session
)
system_prompt = None
custom_agent_prompt_msg = None

View File

@@ -856,6 +856,11 @@ def handle_stream_message_objects(
reserved_tokens=reserved_token_count,
)
# Release any read transaction before entering the long-running LLM stream.
# Without this, the request-scoped session can keep a connection checked out
# for the full stream duration.
db_session.commit()
# The stream generator can resume on a different worker thread after early yields.
# Set this right before launching the LLM loop so run_in_background copies the right context.
if new_msg_req.mock_llm_response is not None:

View File

@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
SMTP_USER = os.environ.get("SMTP_USER") or ""
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""

View File

@@ -167,6 +167,14 @@ CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
# How long a queued user-file-project-sync task remains valid.
# Should be short enough to discard stale queue entries under load while still
# allowing workers enough time to pick up new tasks.
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before user-file-project-sync producers stop enqueuing.
# This applies backpressure when workers are falling behind.
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -459,6 +467,7 @@ class OnyxRedisLocks:
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"

View File

@@ -16,6 +16,22 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
def _is_rate_limit_error(error: HttpError) -> bool:
"""Google sometimes returns rate-limit errors as 403 with reason
'userRateLimitExceeded' instead of 429. This helper detects both."""
if error.resp.status == 429:
return True
if error.resp.status != 403:
return False
error_details = getattr(error, "error_details", None) or []
for detail in error_details:
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
return True
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
@@ -57,7 +73,7 @@ def _execute_with_retry(request: Any) -> Any:
except HttpError as error:
attempt += 1
if error.resp.status == 429:
if _is_rate_limit_error(error):
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
@@ -140,16 +156,16 @@ def _execute_single_retrieval(
)
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e
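
To show the effect of the 403 handling in isolation, here is a small self-contained re-statement of the detection logic with stand-in error objects; the real code inspects googleapiclient's HttpError, which carries resp.status and error_details.

class FakeResp:
    def __init__(self, status: int) -> None:
        self.status = status

class FakeHttpError(Exception):
    def __init__(self, status: int, reason: str | None = None) -> None:
        self.resp = FakeResp(status)
        self.error_details = [{"reason": reason}] if reason else []

def looks_like_rate_limit(error: FakeHttpError) -> bool:
    # Same decision as _is_rate_limit_error: 429 always counts, 403 only
    # when the error details carry a rate-limit reason.
    if error.resp.status == 429:
        return True
    if error.resp.status != 403:
        return False
    return any(
        d.get("reason") in {"userRateLimitExceeded", "rateLimitExceeded"}
        for d in error.error_details
    )

assert looks_like_rate_limit(FakeHttpError(429))
assert looks_like_rate_limit(FakeHttpError(403, "userRateLimitExceeded"))
assert not looks_like_rate_limit(FakeHttpError(403, "insufficientPermissions"))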

View File

@@ -0,0 +1,96 @@
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
The office365 library's GraphClient requires an ``AzureEnvironment`` string
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
cloud. Our connectors instead expose free-text ``authority_host`` and
``graph_api_host`` fields so the frontend doesn't need to know about SDK
internals.
This module bridges the gap: given the two host URLs the user configured, it
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
"""
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
from pydantic import BaseModel
from onyx.connectors.exceptions import ConnectorValidationError
class MicrosoftGraphEnvironment(BaseModel):
"""One row of the inverse mapping."""
environment: str
graph_host: str
authority_host: str
sharepoint_domain_suffix: str
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Global,
graph_host="https://graph.microsoft.com",
authority_host="https://login.microsoftonline.com",
sharepoint_domain_suffix="sharepoint.com",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentHigh,
graph_host="https://graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentDoD,
graph_host="https://dod-graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.China,
graph_host="https://microsoftgraph.chinacloudapi.cn",
authority_host="https://login.chinacloudapi.cn",
sharepoint_domain_suffix="sharepoint.cn",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Germany,
graph_host="https://graph.microsoft.de",
authority_host="https://login.microsoftonline.de",
sharepoint_domain_suffix="sharepoint.de",
),
]
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
env.graph_host: env for env in _ENVIRONMENTS
}
def resolve_microsoft_environment(
graph_api_host: str,
authority_host: str,
) -> MicrosoftGraphEnvironment:
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
Raises ``ConnectorValidationError`` when the combination is unknown or
internally inconsistent (e.g. a GCC-High graph host paired with a
commercial authority host).
"""
graph_api_host = graph_api_host.rstrip("/")
authority_host = authority_host.rstrip("/")
env = _GRAPH_HOST_INDEX.get(graph_api_host)
if env is None:
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
raise ConnectorValidationError(
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
f"Recognised hosts: {known}"
)
if env.authority_host != authority_host:
raise ConnectorValidationError(
f"Authority host '{authority_host}' is inconsistent with "
f"graph API host '{graph_api_host}'. "
f"Expected authority host '{env.authority_host}' "
f"for the {env.environment} environment."
)
return env
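
A quick usage sketch of the resolver, using the commercial-cloud hosts from the table above (trailing slashes are stripped before matching); it assumes the connector dependencies, including the office365 library, are installed.

from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment

env = resolve_microsoft_environment(
    graph_api_host="https://graph.microsoft.com",
    authority_host="https://login.microsoftonline.com/",
)
print(env.environment)               # the Global value from AzureEnvironment
print(env.sharepoint_domain_suffix)  # "sharepoint.com"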

View File

@@ -6,6 +6,7 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
return {
key: [str(item) for item in val] if isinstance(val, list) else str(val)
for key, val in v.items()
}
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None
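
The new validator coerces connector-supplied metadata before pydantic validates the declared dict[str, str | list[str]] type; a standalone re-statement of the coercion for illustration:

def coerce_metadata_values(v: dict) -> dict:
    # Scalars become str; list values are converted element-wise.
    return {
        key: [str(item) for item in val] if isinstance(val, list) else str(val)
        for key, val in v.items()
    }

print(coerce_metadata_values({"page_count": 3, "labels": [1, "two", True]}))
# {'page_count': '3', 'labels': ['1', 'two', 'True']}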

View File

@@ -47,6 +47,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -146,7 +147,9 @@ class DriveItemData(BaseModel):
self.id,
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
)
return DriveItem(graph_client, path)
item = DriveItem(graph_client, path)
item.set_property("id", self.id)
return item
# The office365 library's ClientContext caches the access token from its
@@ -837,10 +840,20 @@ class SharepointConnector(
self._cached_rest_ctx: ClientContext | None = None
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
self.graph_api_base = f"{self.graph_api_host}/v1.0"
self.sharepoint_domain_suffix = sharepoint_domain_suffix
self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
logger.warning(
f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
f"for the {resolved_env.environment} environment. "
f"Using '{resolved_env.sharepoint_domain_suffix}'."
)
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
@@ -1592,6 +1605,7 @@ class SharepointConnector(
if certificate_data is None:
raise RuntimeError("Failed to load certificate")
logger.info(f"Creating MSAL app with authority url {authority_url}")
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=sp_client_id,
@@ -1623,7 +1637,9 @@ class SharepointConnector(
raise ConnectorValidationError("Failed to acquire token for graph")
return token
self._graph_client = GraphClient(_acquire_token_for_graph)
self._graph_client = GraphClient(
_acquire_token_for_graph, environment=self._azure_environment
)
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
org = self.graph_client.organization.get().execute_query()
if not org or len(org) == 0:

View File

@@ -11,6 +11,7 @@ from dateutil import parser
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -258,3 +259,21 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch
def validate_connector_settings(self) -> None:
"""
Very basic validation, we could do more here
"""
if not self.base_url.startswith("https://") and not self.base_url.startswith(
"http://"
):
raise ConnectorValidationError(
"Base URL must start with https:// or http://"
)
try:
get_all_post_ids(self.slab_bot_token)
except ConnectorMissingCredentialError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")

View File

@@ -23,6 +23,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -73,8 +74,11 @@ class TeamsConnector(
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
# impls for BaseConnector
@@ -106,7 +110,9 @@ class TeamsConnector(
return token
self.graph_client = GraphClient(_acquire_token_func)
self.graph_client = GraphClient(
_acquire_token_func, environment=self._azure_environment
)
return None
def validate_connector_settings(self) -> None:

View File

@@ -59,12 +59,11 @@ def _build_index_filters(
base_filters = user_provided_filters or BaseFilters()
if (
user_provided_filters
and user_provided_filters.document_set is None
and persona_document_sets is not None
):
base_filters.document_set = persona_document_sets
document_set_filter = (
base_filters.document_set
if base_filters.document_set is not None
else persona_document_sets
)
time_filter = base_filters.time_cutoff or persona_time_cutoff
source_filter = base_filters.source_type
@@ -120,7 +119,7 @@ def _build_index_filters(
user_file_ids=user_file_ids,
project_id=project_id,
source_type=source_filter,
document_set=persona_document_sets,
document_set=document_set_filter,
time_cutoff=time_filter,
tags=base_filters.tags,
access_control_list=user_acl_filters,

View File

@@ -0,0 +1,21 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
db_session: Session,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
return server
def update_code_interpreter_server_enabled(
db_session: Session,
enabled: bool,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
server.server_enabled = enabled
db_session.commit()
return server

View File

@@ -1,11 +1,102 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_schemas_needing_migration(
tenant_schemas: list[str], head_rev: str
) -> list[str]:
"""Return only schemas whose current alembic version is not at head.
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
into a temp table one at a time. This avoids building a massive UNION ALL
query (which locks the DB and times out at 17k+ schemas) and instead
acquires locks sequentially, one schema per iteration.
"""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Populate a temp input table with exactly the schemas we care about.
# The DO block reads from this table so it only iterates the requested
# schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
conn.execute(
text(
"INSERT INTO _tenant_schemas_input (schema_name) "
"SELECT unnest(CAST(:schemas AS text[]))"
),
{"schemas": tenant_schemas},
)
conn.execute(
text(
"CREATE TEMP TABLE _alembic_version_snapshot "
"(schema_name text, version_num text)"
)
)
conn.execute(
text(
"""
DO $$
DECLARE
s text;
schemas text[];
BEGIN
SELECT array_agg(schema_name) INTO schemas
FROM _tenant_schemas_input;
IF schemas IS NULL THEN
RAISE NOTICE 'No tenant schemas found.';
RETURN;
END IF;
FOREACH s IN ARRAY schemas LOOP
BEGIN
EXECUTE format(
'INSERT INTO _alembic_version_snapshot
SELECT %L, version_num FROM %I.alembic_version',
s, s
);
EXCEPTION
-- undefined_table: schema exists but has no alembic_version
-- table yet (new tenant, not yet migrated).
-- invalid_schema_name: tenant is registered but its
-- PostgreSQL schema does not exist yet (e.g. provisioning
-- incomplete). Both cases mean no version is available and
-- the schema will be included in the migration list.
WHEN undefined_table THEN NULL;
WHEN invalid_schema_name THEN NULL;
END;
END LOOP;
END;
$$
"""
)
)
rows = conn.execute(
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
)
version_by_schema = {row[0]: row[1] for row in rows}
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
# Schemas missing from the snapshot have no alembic_version table yet and
# also need migration. version_by_schema.get(s) returns None for those,
# and None != head_rev, so they are included automatically.
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
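
The final list comprehension is what makes never-migrated schemas fall through: a schema missing from the snapshot maps to None, which can never equal the head revision. A tiny worked example with made-up data:

head_rev = "abc123"
version_by_schema = {"tenant_a": "abc123", "tenant_b": "000999"}  # tenant_c has no alembic_version yet
tenant_schemas = ["tenant_a", "tenant_b", "tenant_c"]

needs_migration = [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
print(needs_migration)  # ['tenant_b', 'tenant_c']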
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

View File

@@ -4940,6 +4940,11 @@ class ScimUserMapping(Base):
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
department: Mapped[str | None] = mapped_column(String, nullable=True)
manager: Mapped[str | None] = mapped_column(String, nullable=True)
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4978,3 +4983,12 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -2,6 +2,7 @@ import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from uuid import UUID
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -13,18 +14,26 @@ from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
def seed_chat_history(
num_sessions: int,
num_messages: int,
days: int,
user_id: UUID | None = None,
persona_id: int | None = None,
) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per session
days: the number of days looking backwards from the current time over which to randomize
the times.
user_id: optional user to associate with sessions
persona_id: optional persona/assistant to associate with sessions
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")

View File

@@ -121,6 +121,7 @@ class VespaDocumentUserFields:
"""
user_projects: list[int] | None = None
personas: list[int] | None = None
@dataclass

View File

@@ -148,6 +148,7 @@ class MetadataUpdateRequest(BaseModel):
hidden: bool | None = None
secondary_index_updated: bool | None = None
project_ids: set[int] | None = None
persona_ids: set[int] | None = None
class IndexRetrievalFilters(BaseModel):

View File

@@ -50,6 +50,7 @@ from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
@@ -215,6 +216,7 @@ def _convert_onyx_chunk_to_opensearch_document(
# OpenSearch and it will not store any data at all for this field, which
# is different from supplying an empty list.
user_projects=chunk.user_project or None,
personas=chunk.personas or None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
@@ -362,6 +364,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
if user_fields and user_fields.user_projects
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
else None
),
)
try:
@@ -709,6 +716,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
update_request.project_ids
)
if update_request.persona_ids is not None:
properties_to_update[PERSONAS_FIELD_NAME] = list(
update_request.persona_ids
)
if not properties_to_update:
if len(update_request.document_ids) > 1:

View File

@@ -41,6 +41,7 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
USER_PROJECTS_FIELD_NAME = "user_projects"
PERSONAS_FIELD_NAME = "personas"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -156,6 +157,7 @@ class DocumentChunk(BaseModel):
document_sets: list[str] | None = None
user_projects: list[int] | None = None
personas: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@@ -485,6 +487,7 @@ class DocumentSchema:
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
PERSONAS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.

View File

@@ -181,6 +181,11 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field personas type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -689,6 +689,9 @@ class VespaIndex(DocumentIndex):
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(
document_ids=[doc_id],
doc_id_to_chunk_cnt={
@@ -699,6 +702,7 @@ class VespaIndex(DocumentIndex):
boost=fields.boost if fields is not None else None,
hidden=fields.hidden if fields is not None else None,
project_ids=project_ids,
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])

View File

@@ -46,6 +46,7 @@ from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -218,6 +219,7 @@ def _index_vespa_chunk(
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
PERSONAS: chunk.personas if chunk.personas is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -183,6 +183,10 @@ def _update_single_chunk(
model_config = {"frozen": True}
assign: list[int]
class _Personas(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based on the Vespa schema. Changes to the
@@ -193,6 +197,7 @@ def _update_single_chunk(
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
personas: _Personas | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
@@ -227,6 +232,11 @@ def _update_single_chunk(
if update_request.project_ids is not None
else None
)
personas_update: _Personas | None = (
_Personas(assign=list(update_request.persona_ids))
if update_request.persona_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
@@ -234,6 +244,7 @@ def _update_single_chunk(
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
personas=personas_update,
)
vespa_put_request = _VespaPutRequest(
@@ -554,10 +565,9 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
# Avoid over-fetching a very large candidate set for global-phase reranking.
# Keep enough headroom for quality while capping cost on larger indices.
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)
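
A worked example of the new target_hits formula; the real RERANK_COUNT is imported in the module above, and the value 1000 below is an assumption for illustration only.

RERANK_COUNT = 1000  # assumed value for the sketch

for num_to_retrieve in (10, 50, 500):
    target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
    print(num_to_retrieve, "->", target_hits)
# 10 -> 100, 50 -> 200, 500 -> 1000 (capped at RERANK_COUNT)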

View File

@@ -58,6 +58,7 @@ DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
PERSONAS = "personas"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -12,6 +12,9 @@ if TYPE_CHECKING:
class AzureImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -53,6 +56,25 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
deployment_name=credentials.deployment_name,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Azure GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -60,14 +82,44 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation
deployment = self._deployment_name or model
model_name = f"azure/{deployment}"
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model_name,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(
prompt=prompt,
model=model_name,

View File

@@ -12,6 +12,9 @@ if TYPE_CHECKING:
class OpenAIImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -39,6 +42,25 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
api_base=credentials.api_base,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -46,9 +68,38 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model,
api_key=self._api_key,
api_base=self._api_base,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(
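
Both providers gate reference-image edits on the same model-name check; below is a standalone re-statement of that gate (the Azure variant additionally passes "azure/<deployment>" style names, which rsplit strips the same way).

def model_supports_image_edits(model: str) -> bool:
    # Keep only the part after the last "/" and match gpt-image-* or dall-e-2.
    normalized = model.rsplit("/", 1)[-1]
    return normalized.startswith("gpt-image-") or normalized == "dall-e-2"

assert model_supports_image_edits("gpt-image-1")
assert model_supports_image_edits("azure/gpt-image-1")
assert model_supports_image_edits("dall-e-2")
assert not model_supports_image_edits("dall-e-3")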

View File

@@ -146,6 +146,7 @@ class DocumentIndexingBatchAdapter:
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map

View File

@@ -182,7 +182,7 @@ class UserFileIndexingAdapter:
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
# we are going to index userfiles only once, so we just set the boost to the default
personas=[],
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],

View File

@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -228,6 +229,8 @@ def index_doc_batch_prepare(
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This precedes indexing them into the actual document index."""
documents = sanitize_documents_for_postgres(documents)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]

View File

@@ -112,6 +112,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess"
document_sets: set[str]
user_project: list[int]
personas: list[int]
boost: int
aggregated_chunk_boost_factor: float
# Full ancestor path from root hierarchy node to document's parent.
@@ -126,6 +127,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
user_project: list[int],
personas: list[int],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
@@ -137,6 +139,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,

View File

@@ -0,0 +1,150 @@
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
def _sanitize_string(value: str) -> str:
return value.replace("\x00", "")
def _sanitize_json_like(value: Any) -> Any:
if isinstance(value, str):
return _sanitize_string(value)
if isinstance(value, list):
return [_sanitize_json_like(item) for item in value]
if isinstance(value, tuple):
return tuple(_sanitize_json_like(item) for item in value)
if isinstance(value, dict):
sanitized: dict[Any, Any] = {}
for key, nested_value in value.items():
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
return sanitized
return value
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
return expert.model_copy(
update={
"display_name": (
_sanitize_string(expert.display_name)
if expert.display_name is not None
else None
),
"first_name": (
_sanitize_string(expert.first_name)
if expert.first_name is not None
else None
),
"middle_initial": (
_sanitize_string(expert.middle_initial)
if expert.middle_initial is not None
else None
),
"last_name": (
_sanitize_string(expert.last_name)
if expert.last_name is not None
else None
),
"email": (
_sanitize_string(expert.email) if expert.email is not None else None
),
}
)
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
return ExternalAccess(
external_user_emails={
_sanitize_string(email) for email in external_access.external_user_emails
},
external_user_group_ids={
_sanitize_string(group_id)
for group_id in external_access.external_user_group_ids
},
is_public=external_access.is_public,
)
def sanitize_document_for_postgres(document: Document) -> Document:
cleaned_doc = document.model_copy(deep=True)
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
if cleaned_doc.title is not None:
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
cleaned_doc.parent_hierarchy_raw_node_id
)
cleaned_doc.metadata = {
_sanitize_string(key): (
[_sanitize_string(item) for item in value]
if isinstance(value, list)
else _sanitize_string(value)
)
for key, value in cleaned_doc.metadata.items()
}
if cleaned_doc.doc_metadata is not None:
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
if cleaned_doc.primary_owners is not None:
cleaned_doc.primary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
]
if cleaned_doc.secondary_owners is not None:
cleaned_doc.secondary_owners = [
_sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
]
if cleaned_doc.external_access is not None:
cleaned_doc.external_access = _sanitize_external_access(
cleaned_doc.external_access
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = _sanitize_string(section.link)
if section.text is not None:
section.text = _sanitize_string(section.text)
if section.image_file_id is not None:
section.image_file_id = _sanitize_string(section.image_file_id)
return cleaned_doc
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
return [sanitize_document_for_postgres(document) for document in documents]
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
cleaned_node = node.model_copy(deep=True)
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
if cleaned_node.raw_parent_id is not None:
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
if cleaned_node.link is not None:
cleaned_node.link = _sanitize_string(cleaned_node.link)
if cleaned_node.external_access is not None:
cleaned_node.external_access = _sanitize_external_access(
cleaned_node.external_access
)
return cleaned_node
def sanitize_hierarchy_nodes_for_postgres(
nodes: list[HierarchyNode],
) -> list[HierarchyNode]:
return [sanitize_hierarchy_node_for_postgres(node) for node in nodes]
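
The recursive doc_metadata sanitization can be seen in isolation below: NUL bytes are stripped from strings (and string keys) at any nesting depth while other types pass through. This is a standalone re-statement for the sketch, with the tuple branch omitted for brevity.

from typing import Any

def sanitize_json_like(value: Any) -> Any:
    if isinstance(value, str):
        return value.replace("\x00", "")
    if isinstance(value, list):
        return [sanitize_json_like(item) for item in value]
    if isinstance(value, dict):
        return {
            (sanitize_json_like(k) if isinstance(k, str) else k): sanitize_json_like(v)
            for k, v in value.items()
        }
    return value

print(sanitize_json_like({"title\x00": ["a\x00b", 3], "count": 1}))
# {'title': ['ab', 3], 'count': 1}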

View File

@@ -97,6 +97,9 @@ from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.kg.api import admin_router as kg_admin_router
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.code_interpreter.api import (
admin_router as code_interpreter_admin_router,
)
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
@@ -421,6 +424,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, kg_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, code_interpreter_admin_router
)
include_router_with_global_prefix_prepended(
application, image_generation_admin_router
)

View File

@@ -1,14 +1,68 @@
import re
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_citation_link_destinations(message: str) -> str:
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
if "[[" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _CITATION_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(message)
normalized_message = _normalize_citation_link_destinations(message)
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result
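
A quick before/after of the normalization, assuming the module above is importable; the import path below is a guess and should be adjusted to wherever the Slack formatting module actually lives.

# NOTE: import path is an assumption for the sketch.
from onyx.onyxbot.slack.formatting import _normalize_citation_link_destinations

msg = "See [[1]](https://en.wikipedia.org/wiki/Python_(programming_language))."
print(_normalize_citation_link_destinations(msg))
# See [[1]](<https://en.wikipedia.org/wiki/Python_(programming_language)>).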

View File

@@ -762,6 +762,43 @@ def download_webapp(
)
@router.get("/{session_id}/download-directory/{path:path}")
def download_directory(
session_id: UUID,
path: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Response:
"""
Download a directory as a zip file.
Returns the specified directory as a zip archive.
"""
user_id: UUID = user.id
session_manager = SessionManager(db_session)
try:
result = session_manager.download_directory(session_id, user_id, path)
except ValueError as e:
error_message = str(e)
if "path traversal" in error_message.lower():
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=400, detail=error_message)
if result is None:
raise HTTPException(status_code=404, detail="Directory not found")
zip_bytes, filename = result
return Response(
content=zip_bytes,
media_type="application/zip",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
},
)
@router.post("/{session_id}/upload", response_model=UploadResponse)
def upload_file_endpoint(
session_id: UUID,

View File

@@ -107,27 +107,23 @@ def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int,
)
for cc_pair in cc_pairs:
if cc_pair.connector.source == DocumentSource.CRAFT_FILE:
if (
cc_pair.connector.source == DocumentSource.CRAFT_FILE
and cc_pair.creator_id == user.id
):
return cc_pair.connector.id, cc_pair.credential.id
# Check for orphaned connector (created but cc_pair creation failed previously)
# No cc_pair for this user — find or create the shared CRAFT_FILE connector
existing_connectors = fetch_connectors(
db_session, sources=[DocumentSource.CRAFT_FILE]
)
orphaned_connector = None
connector_id: int | None = None
for conn in existing_connectors:
if conn.name != USER_LIBRARY_CONNECTOR_NAME:
continue
if not conn.credentials:
orphaned_connector = conn
if conn.name == USER_LIBRARY_CONNECTOR_NAME:
connector_id = conn.id
break
if orphaned_connector:
connector_id = orphaned_connector.id
logger.info(
f"Found orphaned User Library connector {connector_id}, completing setup"
)
else:
if connector_id is None:
connector_data = ConnectorBase(
name=USER_LIBRARY_CONNECTOR_NAME,
source=DocumentSource.CRAFT_FILE,

View File

@@ -1,15 +1,19 @@
#!/usr/bin/env python3
"""Generate AGENTS.md by scanning the files directory and populating the template.
This script runs at container startup, AFTER the init container has synced files
from S3. It scans the /workspace/files directory to discover what knowledge sources
are available and generates appropriate documentation.
This script runs during session setup, AFTER files have been synced from S3
and the files symlink has been created. It reads the template from stdin,
replaces the {{KNOWLEDGE_SOURCES_SECTION}} placeholder by scanning the
knowledge source directory, and writes the final AGENTS.md to the output path.
Environment variables:
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
Usage:
printf '%s' "$TEMPLATE" | python3 generate_agents_md.py <output_path> <files_path>
Arguments:
output_path: Path to write the final AGENTS.md
files_path: Path to the files directory to scan for knowledge sources
"""
import os
import sys
from pathlib import Path
@@ -189,49 +193,39 @@ def build_knowledge_sources_section(files_path: Path) -> str:
def main() -> None:
"""Main entry point for container startup script.
Is called by the container startup script to scan /workspace/files and populate
the knowledge sources section.
Reads the template from stdin, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
placeholder by scanning the files directory, and writes the result.
Usage:
printf '%s' "$TEMPLATE" | python3 generate_agents_md.py <output_path> <files_path>
"""
# Read template from environment variable
template = os.environ.get("AGENT_INSTRUCTIONS", "")
if len(sys.argv) != 3:
print(
f"Usage: {sys.argv[0]} <output_path> <files_path>",
file=sys.stderr,
)
sys.exit(1)
output_path = Path(sys.argv[1])
files_path = Path(sys.argv[2])
# Read template from stdin
template = sys.stdin.read()
if not template:
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
template = "# Agent Instructions\n\nNo instructions provided."
print("Error: No template content provided on stdin", file=sys.stderr)
sys.exit(1)
# Scan files directory - check /workspace/files first, then /workspace/demo_data
files_path = Path("/workspace/files")
demo_data_path = Path("/workspace/demo_data")
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
resolved_files_path = files_path.resolve()
# Use demo_data if files doesn't exist or is empty
if not files_path.exists() or not any(files_path.iterdir()):
if demo_data_path.exists():
files_path = demo_data_path
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
knowledge_sources_section = build_knowledge_sources_section(files_path)
# Replace placeholders
content = template
content = content.replace(
# Replace placeholder and write final file
content = template.replace(
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
)
# Write AGENTS.md
output_path = Path("/workspace/AGENTS.md")
output_path.write_text(content)
# Log result
source_count = 0
if files_path.exists():
source_count = len(
[
d
for d in files_path.iterdir()
if d.is_dir() and not d.name.startswith(".")
]
)
print(
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
)
print(f"Generated {output_path} (scanned {resolved_files_path})")
if __name__ == "__main__":

View File

@@ -1348,9 +1348,10 @@ if [ -d /workspace/skills ]; then
echo "Linked skills to /workspace/skills"
fi
# Write agent instructions
# Write agent instructions (scans files dir to populate knowledge sources)
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
printf '%s' '{agent_instructions_escaped}' \
| python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files
# Write opencode config
echo "Writing opencode.json"
@@ -1776,9 +1777,11 @@ set -e
echo "Creating files symlink to {symlink_target}"
ln -sf {symlink_target} {session_path}/files
# Write agent instructions
# Write agent instructions (scans files dir to populate knowledge sources)
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
printf '%s' '{agent_instructions_escaped}' \
| python3 /usr/local/bin/generate_agents_md.py \
{session_path}/AGENTS.md {session_path}/files
# Write opencode config
echo "Writing opencode.json"

View File

@@ -68,6 +68,7 @@ from onyx.server.features.build.db.sandbox import create_sandbox__no_commit
from onyx.server.features.build.db.sandbox import get_running_sandbox_count_by_tenant
from onyx.server.features.build.db.sandbox import get_sandbox_by_session_id
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import get_snapshots_for_session
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
from onyx.server.features.build.sandbox import get_sandbox_manager
@@ -646,16 +647,30 @@ class SessionManager:
if sandbox and sandbox.status.is_active():
# Quick health check to verify sandbox is actually responsive
if self._sandbox_manager.health_check(sandbox.id, timeout=5.0):
# AND verify the session workspace still exists on disk
# (it may have been wiped if the sandbox was re-provisioned)
is_healthy = self._sandbox_manager.health_check(sandbox.id, timeout=5.0)
workspace_exists = (
is_healthy
and self._sandbox_manager.session_workspace_exists(
sandbox.id, existing.id
)
)
if is_healthy and workspace_exists:
logger.info(
f"Returning existing empty session {existing.id} for user {user_id}"
)
return existing
else:
elif not is_healthy:
logger.warning(
f"Empty session {existing.id} has unhealthy sandbox {sandbox.id}. "
f"Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} workspace missing in sandbox "
f"{sandbox.id}. Deleting and creating fresh session."
)
else:
logger.warning(
f"Empty session {existing.id} has no active sandbox "
@@ -1035,6 +1050,23 @@ class SessionManager:
# workspace cleanup fails (e.g., if pod is already terminated)
logger.warning(f"Failed to cleanup session workspace {session_id}: {e}")
# Delete snapshot files from S3 before removing DB records
snapshots = get_snapshots_for_session(self._db_session, session_id)
if snapshots:
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.sandbox.manager.snapshot_manager import (
SnapshotManager,
)
snapshot_manager = SnapshotManager(get_default_file_store())
for snapshot in snapshots:
try:
snapshot_manager.delete_snapshot(snapshot.storage_path)
except Exception as e:
logger.warning(
f"Failed to delete snapshot file {snapshot.storage_path}: {e}"
)
# Delete session (uses flush, caller commits)
return delete_build_session__no_commit(session_id, user_id, self._db_session)
@@ -1903,6 +1935,94 @@ class SessionManager:
return zip_buffer.getvalue(), filename
def download_directory(
self,
session_id: UUID,
user_id: UUID,
path: str,
) -> tuple[bytes, str] | None:
"""
Create a zip file of an arbitrary directory in the session workspace.
Args:
session_id: The session UUID
user_id: The user ID to verify ownership
path: Relative path to the directory (within session workspace)
Returns:
Tuple of (zip_bytes, filename) or None if session not found
Raises:
ValueError: If path traversal attempted or path is not a directory
"""
# Verify session ownership
session = get_build_session(session_id, user_id, self._db_session)
if session is None:
return None
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
if sandbox is None:
return None
# Check if directory exists
try:
self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError:
return None
# Recursively collect all files
def collect_files(dir_path: str) -> list[tuple[str, str]]:
"""Collect all files recursively, returning (full_path, arcname) tuples."""
files: list[tuple[str, str]] = []
try:
entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=dir_path,
)
for entry in entries:
if entry.is_directory:
files.extend(collect_files(entry.path))
else:
# arcname is relative to the target directory
prefix_len = len(path) + 1  # +1 for the "/" separator after the base path
arcname = entry.path[prefix_len:]
files.append((entry.path, arcname))
except ValueError:
pass
return files
file_list = collect_files(path)
# Create zip file in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
for full_path, arcname in file_list:
try:
content = self._sandbox_manager.read_file(
sandbox_id=sandbox.id,
session_id=session_id,
path=full_path,
)
zip_file.writestr(arcname, content)
except ValueError:
pass
zip_buffer.seek(0)
# Use the directory name for the zip filename
dir_name = Path(path).name
safe_name = "".join(
c if c.isalnum() or c in ("-", "_", ".") else "_" for c in dir_name
)
filename = f"{safe_name}.zip"
return zip_buffer.getvalue(), filename
# =========================================================================
# File System Operations
# =========================================================================
@@ -1937,11 +2057,18 @@ class SessionManager:
return None
# Use sandbox manager to list directory (works for both local and K8s)
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
# If the directory doesn't exist (e.g., session workspace not yet loaded),
# return an empty listing rather than erroring out.
try:
raw_entries = self._sandbox_manager.list_directory(
sandbox_id=sandbox.id,
session_id=session_id,
path=path,
)
except ValueError as e:
if "path traversal" in str(e).lower():
raise
return DirectoryListing(path=path, entries=[])
# Filter hidden files and directories
entries: list[FileSystemEntry] = [

View File

@@ -12,11 +12,18 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import UserFileStatus
from onyx.db.models import ChatSession
@@ -27,6 +34,7 @@ from onyx.db.models import UserProject
from onyx.db.persona import get_personas_by_ids
from onyx.db.projects import get_project_token_count
from onyx.db.projects import upload_files_to_user_files_with_indexing
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import ChatSessionRequest
from onyx.server.features.projects.models import TokenCountResponse
@@ -47,6 +55,33 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
queue_depth = get_user_file_project_sync_queue_depth(client_app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
logger.warning(
f"Skipping immediate project sync for user_file_id={user_file_id} due to "
f"queue depth {queue_depth}>{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}. "
"It will be picked up by beat later."
)
return
redis_client = get_redis_client(tenant_id=tenant_id)
enqueued = enqueue_user_file_project_sync_task(
celery_app=client_app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGHEST,
)
if not enqueued:
logger.info(
f"Skipped duplicate project sync enqueue for user_file_id={user_file_id}"
)
return
logger.info(f"Triggered project sync for user_file_id={user_file_id}")
@router.get("", tags=PUBLIC_API_TAGS)
def get_projects(
user: User = Depends(current_user),
@@ -189,15 +224,7 @@ def unlink_user_file_from_project(
db_session.commit()
tenant_id = get_current_tenant_id()
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
_trigger_user_file_project_sync(user_file.id, tenant_id)
return Response(status_code=204)
@@ -241,15 +268,7 @@ def link_user_file_to_project(
db_session.commit()
tenant_id = get_current_tenant_id()
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
)
_trigger_user_file_project_sync(user_file.id, tenant_id)
return UserFileSnapshot.from_model(user_file)

View File

@@ -0,0 +1,47 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
admin_router = APIRouter(prefix="/admin/code-interpreter")
@admin_router.get("/health")
def get_code_interpreter_health(
_: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
try:
client = CodeInterpreterClient()
return CodeInterpreterServerHealth(healthy=client.health())
except ValueError:
return CodeInterpreterServerHealth(healthy=False)
@admin_router.get("")
def get_code_interpreter(
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
ci_server = fetch_code_interpreter_server(db_session)
return CodeInterpreterServer(enabled=ci_server.server_enabled)
@admin_router.put("")
def update_code_interpreter(
update: CodeInterpreterServer,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_code_interpreter_server_enabled(
db_session=db_session,
enabled=update.enabled,
)

View File
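A rough sketch of exercising the new admin endpoints above; the base URL, mount prefix, and authentication handling are assumptions and will differ per deployment:

# Illustrative only: hit the code-interpreter admin routes with an already
# authenticated admin session (URL prefix is a placeholder).
import requests

BASE = "http://localhost:8080/admin/code-interpreter"  # hypothetical mount
session = requests.Session()  # assumes an admin auth cookie is already set

print(session.get(f"{BASE}/health").json())  # {"healthy": true|false}
print(session.get(BASE).json())              # {"enabled": true|false}
session.put(BASE, json={"enabled": False})   # disables the Python tool for users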

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class CodeInterpreterServer(BaseModel):
enabled: bool
class CodeInterpreterServerHealth(BaseModel):
healthy: bool

View File

@@ -35,6 +35,18 @@ if TYPE_CHECKING:
pass
class EmailInviteStatus(str, Enum):
SENT = "SENT"
NOT_CONFIGURED = "NOT_CONFIGURED"
SEND_FAILED = "SEND_FAILED"
DISABLED = "DISABLED"
class BulkInviteResponse(BaseModel):
invited_count: int
email_invite_status: EmailInviteStatus
class VersionResponse(BaseModel):
backend_version: str

View File

@@ -36,6 +36,7 @@ from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
@@ -78,8 +79,10 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import BulkInviteResponse
from onyx.server.manage.models import ChatBackgroundRequest
from onyx.server.manage.models import DefaultAppModeRequest
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import PersonalizationUpdateRequest
from onyx.server.manage.models import TenantInfo
@@ -368,7 +371,7 @@ def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
) -> BulkInviteResponse:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
tenant_id = get_current_tenant_id()
@@ -427,34 +430,41 @@ def bulk_invite_users(
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations only to new users (not already invited or existing)
if ENABLE_EMAIL_INVITES:
if not ENABLE_EMAIL_INVITES:
email_invite_status = EmailInviteStatus.DISABLED
elif not EMAIL_CONFIGURED:
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
else:
try:
for email in emails_needing_seats:
send_user_email_invite(email, current_user, AUTH_TYPE)
email_invite_status = EmailInviteStatus.SENT
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
email_invite_status = EmailInviteStatus.SEND_FAILED
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
if MULTI_TENANT and not DEV_MODE:
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
return BulkInviteResponse(
invited_count=number_of_invited_users,
email_invite_status=email_invite_status,
)
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)

View File
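For clarity, a small sketch of how a frontend or script might surface the new response, using only the fields defined by BulkInviteResponse and EmailInviteStatus above:

# Illustrative only: interpret the bulk-invite response on the client side.
def summarize_invite_result(body: dict) -> str:
    count = body["invited_count"]
    status = body["email_invite_status"]  # SENT | NOT_CONFIGURED | SEND_FAILED | DISABLED
    if status == "SENT":
        return f"Invited {count} user(s); invite emails sent."
    if status == "SEND_FAILED":
        return f"Invited {count} user(s), but invite emails could not be sent."
    return f"Invited {count} user(s); email invites {status.lower().replace('_', ' ')}."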

@@ -587,6 +587,7 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
)
result = gather_stream_full(packets, state_container)
@@ -609,6 +610,7 @@ def handle_send_chat_message(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
additional_context=chat_message_req.additional_context,
external_state_container=state_container,
):
yield get_json_line(obj.model_dump())

View File

@@ -125,6 +125,11 @@ class SendMessageRequest(BaseModel):
# - No CitationInfo packets are emitted during streaming
include_citations: bool = True
# Additional context injected into the LLM call but NOT stored in the DB
# (not shown in chat history). Used e.g. by the Chrome extension to pass
# the current tab URL when "Read this tab" is enabled.
additional_context: str | None = None
@model_validator(mode="after")
def check_chat_session_id_or_info(self) -> "SendMessageRequest":
# If neither is provided, default to creating a new chat session using the

View File
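A sketch of the request body a client such as the Chrome extension might send; the field names follow SendMessageRequest above, everything else is a placeholder:

# Illustrative only: additional_context rides along with the message payload.
payload = {
    "message": "Summarize this page for me",
    "chat_session_id": "00000000-0000-0000-0000-000000000000",
    "additional_context": "Current tab: https://example.com/some-article",
}
# additional_context is forwarded into the LLM call (see handle_send_chat_message)
# but never written to the DB, so it does not appear in stored chat history.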

@@ -1,5 +1,8 @@
import json
from collections.abc import Generator
from typing import Literal
from typing import TypedDict
from typing import Union
import requests
from pydantic import BaseModel
@@ -36,6 +39,39 @@ class ExecuteResponse(BaseModel):
files: list[WorkspaceFile]
class StreamOutputEvent(BaseModel):
"""SSE 'output' event: a chunk of stdout or stderr"""
stream: Literal["stdout", "stderr"]
data: str
class StreamResultEvent(BaseModel):
"""SSE 'result' event: final execution result"""
exit_code: int | None
timed_out: bool
duration_ms: int
files: list[WorkspaceFile]
class StreamErrorEvent(BaseModel):
"""SSE 'error' event: execution-level error"""
message: str
StreamEvent = Union[StreamOutputEvent, StreamResultEvent, StreamErrorEvent]
_SSE_EVENT_MAP: dict[
str, type[StreamOutputEvent | StreamResultEvent | StreamErrorEvent]
] = {
"output": StreamOutputEvent,
"result": StreamResultEvent,
"error": StreamErrorEvent,
}
class CodeInterpreterClient:
"""Client for Code Interpreter service"""
@@ -45,6 +81,34 @@ class CodeInterpreterClient:
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
def _build_payload(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> dict:
payload: dict = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
return payload
def health(self) -> bool:
"""Check if the Code Interpreter service is healthy"""
url = f"{self.base_url}/health"
try:
response = self.session.get(url, timeout=5)
response.raise_for_status()
return response.json().get("status") == "ok"
except Exception as e:
logger.warning(f"Exception caught when checking health, e={e}")
return False
def execute(
self,
code: str,
@@ -52,25 +116,110 @@ class CodeInterpreterClient:
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> ExecuteResponse:
"""Execute Python code"""
"""Execute Python code (batch)"""
url = f"{self.base_url}/v1/execute"
payload = {
"code": code,
"timeout_ms": timeout_ms,
}
if stdin is not None:
payload["stdin"] = stdin
if files:
payload["files"] = files
payload = self._build_payload(code, stdin, timeout_ms, files)
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
response.raise_for_status()
return ExecuteResponse(**response.json())
def execute_streaming(
self,
code: str,
stdin: str | None = None,
timeout_ms: int = 30000,
files: list[FileInput] | None = None,
) -> Generator[StreamEvent, None, None]:
"""Execute Python code with streaming SSE output.
Yields StreamEvent objects (StreamOutputEvent, StreamResultEvent,
StreamErrorEvent) as execution progresses. Falls back to batch
execution if the streaming endpoint is not available (older
code-interpreter versions).
"""
url = f"{self.base_url}/v1/execute/stream"
payload = self._build_payload(code, stdin, timeout_ms, files)
response = self.session.post(
url,
json=payload,
stream=True,
timeout=timeout_ms / 1000 + 10,
)
if response.status_code == 404:
logger.info(
"Streaming endpoint not available, " "falling back to batch execution"
)
response.close()
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
response.raise_for_status()
yield from self._parse_sse(response)
def _parse_sse(
self, response: requests.Response
) -> Generator[StreamEvent, None, None]:
"""Parse SSE streaming response into StreamEvent objects.
Expected format per event:
event: <type>
data: <json>
<blank line>
"""
event_type: str | None = None
data_lines: list[str] = []
for line in response.iter_lines(decode_unicode=True):
if line is None:
continue
if line == "":
# Blank line marks end of an SSE event
if event_type is not None and data_lines:
data = "\n".join(data_lines)
model_cls = _SSE_EVENT_MAP.get(event_type)
if model_cls is not None:
yield model_cls(**json.loads(data))
else:
logger.warning(f"Unknown SSE event type: {event_type}")
event_type = None
data_lines = []
elif line.startswith("event:"):
event_type = line[len("event:") :].strip()
elif line.startswith("data:"):
data_lines.append(line[len("data:") :].strip())
if event_type is not None or data_lines:
logger.warning(
f"SSE stream ended with incomplete event: "
f"event_type={event_type}, data_lines={data_lines}"
)
def _batch_as_stream(
self,
code: str,
stdin: str | None,
timeout_ms: int,
files: list[FileInput] | None,
) -> Generator[StreamEvent, None, None]:
"""Execute via batch endpoint and yield results as stream events."""
result = self.execute(code, stdin, timeout_ms, files)
if result.stdout:
yield StreamOutputEvent(stream="stdout", data=result.stdout)
if result.stderr:
yield StreamOutputEvent(stream="stderr", data=result.stderr)
yield StreamResultEvent(
exit_code=result.exit_code,
timed_out=result.timed_out,
duration_ms=result.duration_ms,
files=result.files,
)
def upload_file(self, file_content: bytes, filename: str) -> str:
"""Upload file to Code Interpreter and return file_id"""
url = f"{self.base_url}/v1/files"

View File
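A small consumer sketch for the streaming client above, assuming CODE_INTERPRETER_BASE_URL is configured; the executed snippet is a placeholder:

# Illustrative only: iterate the SSE-backed stream and branch on event type.
from onyx.tools.tool_implementations.python.code_interpreter_client import (
    CodeInterpreterClient,
    StreamErrorEvent,
    StreamOutputEvent,
    StreamResultEvent,
)

client = CodeInterpreterClient()  # raises ValueError if no base URL is configured

# On the wire each SSE event is framed as:
#   event: output
#   data: {"stream": "stdout", "data": "hello\n"}
#   <blank line>
for event in client.execute_streaming(code="print('hello')"):
    if isinstance(event, StreamOutputEvent):
        print(event.data, end="")  # incremental stdout/stderr chunk
    elif isinstance(event, StreamResultEvent):
        print(f"\nexit_code={event.exit_code} timed_out={event.timed_out}")
    elif isinstance(event, StreamErrorEvent):
        raise RuntimeError(event.message)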

@@ -12,6 +12,7 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -28,6 +29,15 @@ from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamErrorEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamOutputEvent,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
StreamResultEvent,
)
from onyx.utils.logger import setup_logger
@@ -94,8 +104,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
is_available = bool(CODE_INTERPRETER_BASE_URL)
return is_available
if not CODE_INTERPRETER_BASE_URL:
return False
server = fetch_code_interpreter_server(db_session)
return server.server_enabled
def tool_definition(self) -> dict:
return {
@@ -181,19 +193,50 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
try:
logger.debug(f"Executing code: {code}")
# Execute code with timeout
response = client.execute(
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
)
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
# Handle generated files
@@ -202,7 +245,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in response.files:
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
@@ -258,26 +301,23 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit delta with stdout/stderr and generated files
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=truncated_stdout,
stderr=truncated_stderr,
file_ids=generated_file_ids,
),
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=response.exit_code,
timed_out=response.timed_out,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if response.exit_code == 0 else truncated_stderr,
error=None if result_event.exit_code == 0 else truncated_stderr,
)
# Serialize result for LLM

View File

@@ -6,6 +6,8 @@ aioboto3==15.1.0
# via onyx
aiobotocore==2.24.0
# via aioboto3
aiofile==3.9.0
# via py-key-value-aio
aiofiles==25.1.0
# via
# aioboto3
@@ -40,8 +42,10 @@ anyio==4.11.0
# httpx
# mcp
# openai
# py-key-value-aio
# sse-starlette
# starlette
# watchfiles
argon2-cffi==23.1.0
# via pwdlib
argon2-cffi-bindings==25.1.0
@@ -74,9 +78,7 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12'
bcrypt==4.3.0
# via pwdlib
beartype==0.22.6
# via
# py-key-value-aio
# py-key-value-shared
# via py-key-value-aio
beautifulsoup4==4.12.3
# via
# atlassian-python-api
@@ -110,6 +112,8 @@ cachetools==6.2.2
# via
# google-auth
# py-key-value-aio
caio==0.9.25
# via aiofile
celery==5.5.1
# via onyx
certifi==2025.11.12
@@ -170,7 +174,6 @@ cloudpickle==3.1.2
# via
# dask
# distributed
# pydocket
cobble==0.1.4
# via mammoth
cohere==5.6.1
@@ -218,8 +221,6 @@ deprecated==1.3.1
# pygithub
discord-py==2.4.0
# via onyx
diskcache==5.6.3
# via py-key-value-aio
distributed==2026.1.1
# via onyx
distro==1.9.0
@@ -256,8 +257,6 @@ exceptiongroup==1.3.0
# via
# braintrust
# fastmcp
fakeredis==2.33.0
# via pydocket
fastapi==0.128.0
# via
# fastapi-limiter
@@ -273,7 +272,7 @@ fastapi-users-db-sqlalchemy==7.0.0
# via onyx
fastavro==1.12.1
# via cohere
fastmcp==2.14.2
fastmcp==3.0.2
# via onyx
fastuuid==0.14.0
# via litellm
@@ -478,7 +477,9 @@ jsonpatch==1.33
jsonpointer==3.0.0
# via jsonpatch
jsonref==1.1.0
# via onyx
# via
# fastmcp
# onyx
jsonschema==4.25.1
# via
# litellm
@@ -513,8 +514,6 @@ locket==1.0.0
# via
# distributed
# partd
lupa==2.6
# via fakeredis
lxml==5.3.0
# via
# htmldate
@@ -556,7 +555,7 @@ marshmallow==3.26.2
# via dataclasses-json
matrix-client==0.3.2
# via zulip
mcp==1.25.0
mcp==1.26.0
# via
# claude-agent-sdk
# fastmcp
@@ -613,7 +612,7 @@ oauthlib==3.2.2
# kubernetes
# onyx
# requests-oauthlib
office365-rest-python-client==2.5.9
office365-rest-python-client==2.6.2
# via onyx
olefile==0.47
# via
@@ -642,22 +641,16 @@ opensearch-py==3.0.0
opentelemetry-api==1.39.1
# via
# ddtrace
# fastmcp
# langfuse
# openinference-instrumentation
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-exporter-prometheus
# opentelemetry-instrumentation
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pydocket
opentelemetry-exporter-otlp-proto-common==1.39.1
# via opentelemetry-exporter-otlp-proto-http
opentelemetry-exporter-otlp-proto-http==1.39.1
# via langfuse
opentelemetry-exporter-prometheus==0.60b1
# via pydocket
opentelemetry-instrumentation==0.60b1
# via pydocket
opentelemetry-proto==1.39.1
# via
# onyx
@@ -668,17 +661,15 @@ opentelemetry-sdk==1.39.1
# langfuse
# openinference-instrumentation
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-exporter-prometheus
opentelemetry-semantic-conventions==0.60b1
# via
# opentelemetry-instrumentation
# opentelemetry-sdk
# via opentelemetry-sdk
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
# via langsmith
packaging==24.2
# via
# dask
# distributed
# fastmcp
# google-cloud-aiplatform
# google-cloud-bigquery
# huggingface-hub
@@ -689,7 +680,6 @@ packaging==24.2
# langsmith
# marshmallow
# onnxruntime
# opentelemetry-instrumentation
# pytest
# pywikibot
pandas==2.3.3
@@ -702,8 +692,6 @@ passlib==1.7.4
# via onyx
pathable==0.4.4
# via jsonschema-path
pathvalidate==3.3.1
# via py-key-value-aio
pdfminer-six==20251107
# via markitdown
pillow==12.1.1
@@ -723,9 +711,7 @@ ply==3.11
prometheus-client==0.23.1
# via
# onyx
# opentelemetry-exporter-prometheus
# prometheus-fastapi-instrumentator
# pydocket
prometheus-fastapi-instrumentator==7.1.0
# via onyx
prompt-toolkit==3.0.52
@@ -764,12 +750,8 @@ pwdlib==0.3.0
# via fastapi-users
py==1.11.0
# via retry
py-key-value-aio==0.3.0
# via
# fastmcp
# pydocket
py-key-value-shared==0.3.0
# via py-key-value-aio
py-key-value-aio==0.4.4
# via fastmcp
pyairtable==3.0.1
# via onyx
pyasn1==0.6.2
@@ -806,8 +788,6 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pydocket==0.16.3
# via fastmcp
pyee==13.0.0
# via playwright
pygithub==2.5.0
@@ -879,8 +859,6 @@ python-http-client==3.3.7
# via sendgrid
python-iso639==2025.11.16
# via unstructured
python-json-logger==4.0.0
# via pydocket
python-magic==0.4.27
# via unstructured
python-multipart==0.0.22
@@ -918,6 +896,7 @@ pyyaml==6.0.3
# via
# dask
# distributed
# fastmcp
# huggingface-hub
# jsonschema-path
# kubernetes
@@ -928,11 +907,8 @@ rapidfuzz==3.13.0
# unstructured
redis==5.0.8
# via
# fakeredis
# fastapi-limiter
# onyx
# py-key-value-aio
# pydocket
referencing==0.36.2
# via
# jsonschema
@@ -1007,7 +983,6 @@ rich==14.2.0
# via
# cyclopts
# fastmcp
# pydocket
# rich-rst
# typer
rich-rst==1.3.2
@@ -1056,9 +1031,7 @@ sniffio==1.3.1
# anyio
# openai
sortedcontainers==2.4.0
# via
# distributed
# fakeredis
# via distributed
soupsieve==2.8
# via beautifulsoup4
sqlalchemy==2.0.15
@@ -1124,9 +1097,7 @@ tqdm==4.67.1
trafilatura==1.12.2
# via onyx
typer==0.20.0
# via
# mcp
# pydocket
# via mcp
types-awscrt==0.28.4
# via botocore-stubs
types-openpyxl==3.0.4.7
@@ -1162,11 +1133,10 @@ typing-extensions==4.15.0
# opentelemetry-exporter-otlp-proto-http
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# py-key-value-shared
# py-key-value-aio
# pyairtable
# pydantic
# pydantic-core
# pydocket
# pyee
# pygithub
# python-docx
@@ -1234,6 +1204,8 @@ vine==5.1.0
# kombu
voyageai==0.2.3
# via onyx
watchfiles==1.1.1
# via fastmcp
wcwidth==0.2.14
# via prompt-toolkit
webencodings==0.5.1
@@ -1254,7 +1226,6 @@ wrapt==1.17.3
# deprecated
# langfuse
# openinference-instrumentation
# opentelemetry-instrumentation
# unstructured
xlrd==2.0.2
# via markitdown

View File

@@ -288,7 +288,7 @@ matplotlib-inline==0.2.1
# via
# ipykernel
# ipython
mcp==1.25.0
mcp==1.26.0
# via claude-agent-sdk
multidict==6.7.0
# via
@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.0
onyx-devtools==0.6.1
# via onyx
openai==2.14.0
# via

View File

@@ -211,7 +211,7 @@ litellm==1.81.6
# via onyx
markupsafe==3.0.3
# via jinja2
mcp==1.25.0
mcp==1.26.0
# via claude-agent-sdk
monotonic==1.6
# via posthog

View File

@@ -246,7 +246,7 @@ litellm==1.81.6
# via onyx
markupsafe==3.0.3
# via jinja2
mcp==1.25.0
mcp==1.26.0
# via claude-agent-sdk
mpmath==1.3.0
# via sympy

View File

@@ -95,6 +95,7 @@ def generate_dummy_chunk(
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
user_project=[],
personas=[],
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,

View File

@@ -3,8 +3,8 @@ set -e
cleanup() {
echo "Error occurred. Cleaning up..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -55,6 +55,10 @@ else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
# Start the Code Interpreter container
echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"

View File

@@ -243,12 +243,12 @@ USAGE_LIMIT_CHUNKS_INDEXED_PAID = int(
)
# Per-week API calls using API keys or Personal Access Tokens
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "400"))
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "0"))
USAGE_LIMIT_API_CALLS_PAID = int(os.environ.get("USAGE_LIMIT_API_CALLS_PAID", "40000"))
# Per-week non-streaming API calls (more expensive, so lower limits)
USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL = int(
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "80")
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "0")
)
USAGE_LIMIT_NON_STREAMING_CALLS_PAID = int(
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_PAID", "160")

View File

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from dotenv import load_dotenv
@@ -46,11 +47,15 @@ def mock_current_admin_user() -> MagicMock:
@pytest.fixture(scope="function")
def client() -> Generator[TestClient, None, None]:
# Initialize TestClient with the FastAPI app using a no-op test lifespan
# Initialize TestClient with the FastAPI app using a no-op test lifespan.
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
# CollectorRegistry" errors when multiple tests each create a new app
# (prometheus registers metrics globally and rejects duplicate names).
get_app = fetch_versioned_implementation(
module="onyx.main", attribute="get_application"
)
app: FastAPI = get_app(lifespan_override=test_lifespan)
with patch("onyx.main.setup_prometheus_metrics"):
app: FastAPI = get_app(lifespan_override=test_lifespan)
# Override the database session dependency with a mock
# (these tests don't actually need DB access)

View File

@@ -990,6 +990,27 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self._respond_json(
200, {"file_id": f"mock-ci-file-{self.server._file_counter}"}
)
elif self.path == "/v1/execute/stream":
if self.server.streaming_enabled:
self._respond_sse(
[
(
"output",
{"stream": "stdout", "data": "mock output\n"},
),
(
"result",
{
"exit_code": 0,
"timed_out": False,
"duration_ms": 50,
"files": [],
},
),
]
)
else:
self._respond_json(404, {"error": "not found"})
elif self.path == "/v1/execute":
self._respond_json(
200,
@@ -1027,6 +1048,17 @@ class _MockCIHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(payload)
def _respond_sse(self, events: list[tuple[str, dict[str, Any]]]) -> None:
frames = []
for event_type, data in events:
frames.append(f"event: {event_type}\ndata: {json.dumps(data)}\n\n")
payload = "".join(frames).encode()
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Content-Length", str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
pass
@@ -1038,6 +1070,7 @@ class MockCodeInterpreterServer(HTTPServer):
super().__init__(("localhost", 0), _MockCIHandler)
self.captured_requests: list[CapturedRequest] = []
self._file_counter = 0
self.streaming_enabled: bool = True
@property
def url(self) -> str:
@@ -1168,17 +1201,19 @@ def test_code_interpreter_receives_chat_files(
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Verify: file uploaded, code executed, staged file cleaned up
# Verify: file uploaded, code executed via streaming, staged file cleaned up
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
delete_requests = mock_ci_server.get_requests(method="DELETE")
assert len(delete_requests) == 1
assert delete_requests[0].path.startswith("/v1/files/")
execute_body = mock_ci_server.get_requests(method="POST", path="/v1/execute")[
0
].json_body()
execute_body = mock_ci_server.get_requests(
method="POST", path="/v1/execute/stream"
)[0].json_body()
assert execute_body["code"] == code
assert len(execute_body["files"]) == 1
assert execute_body["files"][0]["path"] == "data.csv"
@@ -1284,7 +1319,9 @@ def test_code_interpreter_replay_packets_include_code_and_output(
db_session=db_session,
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
# The response contains `packets` — a list of packet-lists, one per
# assistant message. We should have exactly one assistant message.
@@ -1313,3 +1350,76 @@ def test_code_interpreter_replay_packets_include_code_and_output(
delta_obj = delta_packets[0].obj
assert isinstance(delta_obj, PythonToolDelta)
assert "mock output" in delta_obj.stdout
def test_code_interpreter_streaming_fallback_to_batch(
db_session: Session,
mock_ci_server: MockCodeInterpreterServer,
_attach_python_tool_to_default_persona: None,
initialize_file_store: None, # noqa: ARG001
) -> None:
"""When the streaming endpoint is not available (older code-interpreter),
execute_streaming should fall back to the batch /v1/execute endpoint."""
mock_ci_server.captured_requests.clear()
mock_ci_server._file_counter = 0
mock_ci_server.streaming_enabled = False
mock_url = mock_ci_server.url
user = create_test_user(db_session, "ci_fallback_test")
chat_session = create_chat_session(db_session=db_session, user=user)
code = 'print("fallback test")'
msg_req = SendMessageRequest(
message="Print fallback test",
chat_session_id=chat_session.id,
stream=True,
)
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
with (
use_mock_llm() as mock_llm,
patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
mock_url,
),
patch(
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
mock_url,
),
):
mock_llm.add_response(
LLMToolCallResponse(
tool_name="python",
tool_call_id="call_fallback",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
)
mock_llm.forward_till_end()
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
try:
packets = list(
handle_stream_message_objects(
new_msg_req=msg_req, user=user, db_session=db_session
)
)
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
mock_ci_server.streaming_enabled = True
# Streaming was attempted first (returned 404), then fell back to batch
assert (
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
# Verify output still made it through
delta_packets = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, PythonToolDelta)
]
assert len(delta_packets) >= 1
first_delta = delta_packets[0].obj
assert isinstance(first_delta, PythonToolDelta)
assert "mock output" in first_delta.stdout

View File

@@ -0,0 +1,53 @@
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
environment-variable check passes and the DB flag is the deciding factor.
"""
from unittest.mock import patch
from sqlalchemy.orm import Session
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.tools.tool_implementations.python.python_tool import PythonTool
def test_python_tool_unavailable_when_server_disabled(
db_session: Session,
) -> None:
"""With a valid base URL, the tool should be unavailable when
server_enabled is False in the DB."""
server = fetch_code_interpreter_server(db_session)
initial_enabled = server.server_enabled
try:
update_code_interpreter_server_enabled(db_session, enabled=False)
with patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://fake:8888",
):
assert PythonTool.is_available(db_session) is False
finally:
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
def test_python_tool_available_when_server_enabled(
db_session: Session,
) -> None:
"""With a valid base URL, the tool should be available when
server_enabled is True in the DB."""
server = fetch_code_interpreter_server(db_session)
initial_enabled = server.server_enabled
try:
update_code_interpreter_server_enabled(db_session, enabled=True)
with patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
"http://fake:8888",
):
assert PythonTool.is_available(db_session) is True
finally:
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)

View File

@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s"]
ENTRYPOINT ["pytest", "-s", "-rs"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from urllib.parse import urlencode
from uuid import UUID
@@ -8,8 +9,10 @@ from requests.models import CaseInsensitiveDict
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.db.enums import TaskStatus
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestUser
@@ -69,9 +72,42 @@ class QueryHistoryManager:
if end_time:
query_params["end"] = end_time.isoformat()
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
start_response = requests.post(
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.headers, response.content.decode()
start_response.raise_for_status()
request_id = start_response.json()["request_id"]
deadline = time.time() + MAX_DELAY
while time.time() < deadline:
status_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/export-status",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
status_response.raise_for_status()
status = status_response.json()["status"]
if status == TaskStatus.SUCCESS:
break
if status == TaskStatus.FAILURE:
raise RuntimeError("Query history export task failed")
time.sleep(2)
else:
raise TimeoutError(
f"Query history export not completed within {MAX_DELAY} seconds"
)
download_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/download",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
download_response.raise_for_status()
if not download_response.content:
raise RuntimeError(
"Query history CSV download returned zero-length content"
)
return download_response.headers, download_response.content.decode()

View File

@@ -6,16 +6,26 @@ import pytest
from onyx.connectors.slack.models import ChannelType
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
# from tests.load_env_vars import load_env_vars
# load_env_vars()
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
def _provision_slack_channels(
bot_token: str,
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(bot_token)
auth_info = slack_client.auth_test()
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = user_map["admin@example.com"]
if SLACK_ADMIN_EMAIL not in user_map:
raise KeyError(
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
f"Available emails: {sorted(user_map.keys())}"
)
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
(
public_channel,
@@ -27,5 +37,16 @@ def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]
yield public_channel, private_channel
# This part will always run after the test, even if it fails
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
@pytest.fixture()
def slack_perm_sync_test_setup() -> (
Generator[tuple[ChannelType, ChannelType], None, None]
):
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])

View File

@@ -22,6 +22,9 @@ from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.conftest import SLACK_ADMIN_EMAIL
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_1_EMAIL
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_2_EMAIL
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@@ -34,26 +37,24 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackMan
def test_slack_permission_sync(
reset: None, # noqa: ARG001
vespa_client: vespa_fixture, # noqa: ARG001
slack_test_setup: tuple[ChannelType, ChannelType],
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
) -> None:
public_channel, private_channel = slack_test_setup
public_channel, private_channel = slack_perm_sync_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@example.com",
email=SLACK_ADMIN_EMAIL,
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@example.com",
email=SLACK_TEST_USER_1_EMAIL,
)
# Creating a non-admin user
test_user_2: DATestUser = UserManager.create(
email="test_user_2@example.com",
email=SLACK_TEST_USER_2_EMAIL,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
slack_client = SlackManager.get_slack_client(bot_token)
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
@@ -63,7 +64,7 @@ def test_slack_permission_sync(
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
"slack_bot_token": bot_token,
},
user_performing_action=admin_user,
)
@@ -73,6 +74,7 @@ def test_slack_permission_sync(
source=DocumentSource.SLACK,
connector_specific_config={
"channels": [public_channel["name"], private_channel["name"]],
"include_bot_messages": True,
},
access_type=AccessType.SYNC,
groups=[],
@@ -102,14 +104,11 @@ def test_slack_permission_sync(
public_message = "Steve's favorite number is 809752"
private_message = "Sara's favorite number is 346794"
# Add messages to channels
print(f"\n Adding public message to channel: {public_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=public_channel,
message=public_message,
)
print(f"\n Adding private message to channel: {private_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
@@ -127,7 +126,9 @@ def test_slack_permission_sync(
user_performing_action=admin_user,
)
# Run permission sync
# Run permission sync. Since initial_index_should_sync=True for Slack,
# permissions were already set during indexing above — the explicit sync
# should find no changes to apply.
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
@@ -135,59 +136,38 @@ def test_slack_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=2,
number_of_updated_docs=0,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
# Search as admin with access to both channels
print("\nSearching as admin user")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify admin can see messages from both channels
admin_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=admin_user,
)
print(
"\n documents retrieved by admin user: ",
onyx_doc_message_strings,
)
assert public_message in admin_docs
assert private_message in admin_docs
# Ensure admin user can see messages from both channels
assert public_message in onyx_doc_message_strings
assert private_message in onyx_doc_message_strings
# Search as test_user_2 with access to only the public channel
print("\n Searching as test_user_2")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_2 can only see public channel messages
user_2_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_2,
)
print(
"\n documents retrieved by test_user_2: ",
onyx_doc_message_strings,
)
assert public_message in user_2_docs
assert private_message not in user_2_docs
# Ensure test_user_2 can only see messages from the public channel
assert public_message in onyx_doc_message_strings
assert private_message not in onyx_doc_message_strings
# Search as test_user_1 with access to both channels
print("\n Searching as test_user_1")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_1 can see both channels (member of private channel)
user_1_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 before being removed from private channel: ",
onyx_doc_message_strings,
)
assert public_message in user_1_docs
assert private_message in user_1_docs
# Ensure test_user_1 can see messages from both channels
assert public_message in onyx_doc_message_strings
assert private_message in onyx_doc_message_strings
# ----------------------MAKE THE CHANGES--------------------------
print("\n Removing test_user_1 from the private channel")
before = datetime.now(timezone.utc)
# Remove test_user_1 from the private channel
before = datetime.now(timezone.utc)
desired_channel_members = [admin_user]
SlackManager.set_channel_members(
slack_client=slack_client,
@@ -206,24 +186,16 @@ def test_slack_permission_sync(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure test_user_1 can no longer see messages from the private channel
# Search as test_user_1 with access to only the public channel
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_1 can no longer see private channel after removal
user_1_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 after being removed from private channel: ",
onyx_doc_message_strings,
)
# Ensure test_user_1 can only see messages from the public channel
assert public_message in onyx_doc_message_strings
assert private_message not in onyx_doc_message_strings
assert public_message in user_1_docs
assert private_message not in user_1_docs
# NOTE(rkuo): it isn't yet clear if the reason these were previously xfail'd
@@ -235,21 +207,19 @@ def test_slack_permission_sync(
def test_slack_group_permission_sync(
reset: None, # noqa: ARG001
vespa_client: vespa_fixture, # noqa: ARG001
slack_test_setup: tuple[ChannelType, ChannelType],
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
) -> None:
"""
This test ensures that permission sync overrides onyx group access.
"""
public_channel, private_channel = slack_test_setup
public_channel, private_channel = slack_perm_sync_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@example.com",
email=SLACK_ADMIN_EMAIL,
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@example.com",
email=SLACK_TEST_USER_1_EMAIL,
)
# Create a user group and adding the non-admin user to it
@@ -264,7 +234,8 @@ def test_slack_group_permission_sync(
user_performing_action=admin_user,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
slack_client = SlackManager.get_slack_client(bot_token)
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
@@ -282,7 +253,7 @@ def test_slack_group_permission_sync(
credential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
"slack_bot_token": bot_token,
},
user_performing_action=admin_user,
)
@@ -294,6 +265,7 @@ def test_slack_group_permission_sync(
source=DocumentSource.SLACK,
connector_specific_config={
"channels": [private_channel["name"]],
"include_bot_messages": True,
},
access_type=AccessType.SYNC,
groups=[user_group.id],
@@ -326,7 +298,8 @@ def test_slack_group_permission_sync(
user_performing_action=admin_user,
)
# Run permission sync
# Run permission sync. Since initial_index_should_sync=True for Slack,
# permissions were already set during indexing — no changes expected.
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
@@ -334,8 +307,10 @@ def test_slack_group_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
number_of_updated_docs=0,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
# Verify admin can see the message

View File

@@ -5,22 +5,17 @@ from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from fastmcp import FastMCP
from fastmcp.server.auth import StaticTokenVerifier
from fastmcp.server.server import FunctionTool
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
def make_tool(i: int) -> FunctionTool:
def make_many_tools(mcp: FastMCP) -> None:
def make_tool(i: int) -> None:
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
def tool_name(name: str) -> str: # noqa: ARG001
"""Get secret value."""
return f"Secret value {200 - i}!"
return tool_name
tools = []
for i in range(100):
tools.append(make_tool(i))
return tools
make_tool(i)
if __name__ == "__main__":

View File

@@ -28,7 +28,6 @@ from fastmcp import FastMCP
from fastmcp.server.auth import AccessToken
from fastmcp.server.auth import TokenVerifier
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.server import FunctionTool
# Google's tokeninfo endpoint for validating access tokens
GOOGLE_TOKENINFO_URL = "https://oauth2.googleapis.com/tokeninfo"
@@ -148,24 +147,19 @@ class GoogleOAuthTokenVerifier(TokenVerifier):
await self._http_client.aclose()
def make_tools(mcp: FastMCP) -> list[FunctionTool]:
def make_tools(mcp: FastMCP) -> None:
"""Create test tools for the MCP server."""
tools: list[FunctionTool] = []
@mcp.tool(name="echo", description="Echo back the input message")
def echo(message: str) -> str:
"""Echo the message back to the caller."""
return f"You said: {message}"
tools.append(echo)
@mcp.tool(name="get_secret", description="Get a secret value (requires auth)")
def get_secret(secret_name: str) -> str:
"""Get a secret value. This proves the token was validated."""
return f"Secret value for '{secret_name}': super-secret-value-12345"
tools.append(get_secret)
@mcp.tool(name="whoami", description="Get information about the authenticated user")
async def whoami() -> dict[str, Any]:
"""Get information about the authenticated user from their Google token."""
@@ -182,9 +176,6 @@ def make_tools(mcp: FastMCP) -> list[FunctionTool]:
"access_type": tok.claims.get("access_type"),
}
tools.append(whoami)
# Add some numbered tools for testing tool discovery
for i in range(5):
@mcp.tool(name=f"oauth_tool_{i}", description=f"Test tool number {i}")
@@ -192,10 +183,6 @@ def make_tools(mcp: FastMCP) -> list[FunctionTool]:
"""A numbered test tool."""
return f"Tool {_i} says hello to {name}!"
tools.append(numbered_tool)
return tools
if __name__ == "__main__":
port = int(sys.argv[1] if len(sys.argv) > 1 else "8006")

View File

@@ -2,7 +2,6 @@ import os
import sys
from fastmcp import FastMCP
from fastmcp.server.server import FunctionTool
mcp = FastMCP("My HTTP MCP")
@@ -13,19 +12,15 @@ def hello(name: str) -> str:
return f"Hello, {name}!"
def make_many_tools() -> list[FunctionTool]:
def make_tool(i: int) -> FunctionTool:
def make_many_tools() -> None:
def make_tool(i: int) -> None:
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
def tool_name(name: str) -> str: # noqa: ARG001
"""Get secret value."""
return f"Secret value {100 - i}!"
return tool_name
tools = []
for i in range(100):
tools.append(make_tool(i))
return tools
make_tool(i)
if __name__ == "__main__":

View File

@@ -15,7 +15,6 @@ from fastapi.responses import Response
from fastmcp import FastMCP
from fastmcp.server.auth.providers.jwt import JWTVerifier
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.server import FunctionTool
from starlette.middleware.base import BaseHTTPMiddleware
# uncomment for debug logs
@@ -37,18 +36,15 @@ Enable authorization code and store the client id and secret.
"""
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
def make_tool(i: int) -> FunctionTool:
def make_many_tools(mcp: FastMCP) -> None:
def make_tool(i: int) -> None:
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
def tool_name(name: str) -> str: # noqa: ARG001
"""Get secret value."""
return f"Secret value {500 - i}!"
return tool_name
tools = []
for i in range(100):
tools.append(make_tool(i))
make_tool(i)
@mcp.tool
async def whoami() -> dict[str, Any]:
@@ -59,9 +55,6 @@ def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
"claims": tok.claims if tok else {},
}
tools.append(whoami)
return tools
# ---------- FASTAPI APP ----------

View File

@@ -10,7 +10,6 @@ from fastmcp import FastMCP
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.auth.auth import TokenVerifier
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.server import FunctionTool
# pip install fastmcp bcrypt
@@ -93,19 +92,15 @@ class ApiKeyVerifier(TokenVerifier):
# ---- server -----------------------------------------------------------------
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
def make_tool(i: int) -> FunctionTool:
def make_many_tools(mcp: FastMCP) -> None:
def make_tool(i: int) -> None:
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
def tool_name(name: str) -> str: # noqa: ARG001
"""Get secret value."""
return f"Secret value {400 - i}!"
return tool_name
tools = []
for i in range(100):
tools.append(make_tool(i))
return tools
make_tool(i)
if __name__ == "__main__":

View File

@@ -106,13 +106,13 @@ class TestGuildDataIsolation:
# Create admin user for tenant 1
admin_user1: DATestUser = UserManager.create(
email=f"discord_admin1+{unique}@example.com",
email=f"discord_admin1_{unique}@example.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create admin user for tenant 2
admin_user2: DATestUser = UserManager.create(
email=f"discord_admin2+{unique}@example.com",
email=f"discord_admin2_{unique}@example.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
@@ -170,10 +170,10 @@ class TestGuildDataIsolation:
# Create admin users for two tenants
admin_user1: DATestUser = UserManager.create(
email=f"discord_list1+{unique}@example.com",
email=f"discord_list1_{unique}@example.com",
)
admin_user2: DATestUser = UserManager.create(
email=f"discord_list2+{unique}@example.com",
email=f"discord_list2_{unique}@example.com",
)
# Create 1 guild in tenant 1
@@ -350,10 +350,10 @@ class TestGuildAccessIsolation:
# Create admin users for two tenants
admin_user1: DATestUser = UserManager.create(
email=f"discord_access1+{unique}@example.com",
email=f"discord_access1_{unique}@example.com",
)
admin_user2: DATestUser = UserManager.create(
email=f"discord_access2+{unique}@example.com",
email=f"discord_access2_{unique}@example.com",
)
# Create a guild in tenant 1

View File

@@ -21,7 +21,7 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None: # noqa: ARG00
# Admin user invites the previously registered and non-registered user
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(f"{INVITED_BASIC_USER}+{unique}@example.com", admin_user)
UserManager.invite_user(f"{INVITED_BASIC_USER}_{unique}@example.com", admin_user)
# Verify users are in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
@@ -40,7 +40,7 @@ def test_non_registered_user_gets_basic_role(
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Admin user invites a non-registered user
invited_email = f"{INVITED_BASIC_USER}+{unique}@example.com"
invited_email = f"{INVITED_BASIC_USER}_{unique}@example.com"
UserManager.invite_user(invited_email, admin_user)
# Non-registered user registers
@@ -58,7 +58,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None: # noqa: A
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Create a user to be invited
invited_user_email = f"invited_user+{unique}@example.com"
invited_user_email = f"invited_user_{unique}@example.com"
# User registers with the same email as the invitation
invited_user: DATestUser = UserManager.create(

View File

@@ -20,13 +20,13 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]: # noqa: ARG0
unique = uuid4().hex
# Creating an admin user for Tenant 1
admin_user1: DATestUser = UserManager.create(
email=f"admin+{unique}@example.com",
email=f"admin_{unique}@example.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
admin_user2: DATestUser = UserManager.create(
email=f"admin2+{unique}@example.com",
email=f"admin2_{unique}@example.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)

View File

@@ -0,0 +1,167 @@
"""
Integration tests for onyx.db.engine.tenant_utils.get_schemas_needing_migration.
These tests require a live database and exercise the function directly,
independent of the alembic migration runner script.
Usage:
pytest tests/integration/multitenant_tests/test_get_schemas_needing_migration.py -v
"""
from __future__ import annotations
import subprocess
import uuid
from collections.abc import Generator
import pytest
from sqlalchemy import text
from sqlalchemy.engine import Engine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
_BACKEND_DIR = __file__[: __file__.index("/tests/")]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def engine() -> Engine:
return SqlEngine.get_engine()
@pytest.fixture
def current_head_rev() -> str:
result = subprocess.run(
["alembic", "heads", "--resolve-dependencies"],
cwd=_BACKEND_DIR,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
assert (
result.returncode == 0
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
rev = result.stdout.strip().split()[0]
assert len(rev) > 0
return rev
@pytest.fixture
def tenant_schema_at_head(
engine: Engine, current_head_rev: str
) -> Generator[str, None, None]:
"""Tenant schema with alembic_version already at head — should be excluded."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.execute(
text(
f'CREATE TABLE "{schema}".alembic_version '
f"(version_num VARCHAR(32) NOT NULL)"
)
)
conn.execute(
text(f'INSERT INTO "{schema}".alembic_version (version_num) VALUES (:rev)'),
{"rev": current_head_rev},
)
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
"""Tenant schema with no tables — should be included (needs migration)."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
def tenant_schema_stale_rev(engine: Engine) -> Generator[str, None, None]:
"""Tenant schema with a non-head revision — should be included (needs migration)."""
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
with engine.connect() as conn:
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
conn.execute(
text(
f'CREATE TABLE "{schema}".alembic_version '
f"(version_num VARCHAR(32) NOT NULL)"
)
)
conn.execute(
text(
f'INSERT INTO "{schema}".alembic_version (version_num) '
f"VALUES ('stalerev000000000000')"
)
)
conn.commit()
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_classifies_all_cases(
current_head_rev: str,
tenant_schema_at_head: str,
tenant_schema_empty: str,
tenant_schema_stale_rev: str,
) -> None:
"""Correctly classifies all three schema states:
- at head → excluded
- no table → included (needs migration)
- stale rev → included (needs migration)
"""
all_schemas = [tenant_schema_at_head, tenant_schema_empty, tenant_schema_stale_rev]
result = get_schemas_needing_migration(all_schemas, current_head_rev)
assert tenant_schema_at_head not in result
assert tenant_schema_empty in result
assert tenant_schema_stale_rev in result
def test_idempotent(
current_head_rev: str,
tenant_schema_at_head: str,
tenant_schema_empty: str,
) -> None:
"""Calling the function twice returns the same result.
Verifies that the DROP TABLE IF EXISTS guards correctly clean up temp
tables so a second call succeeds even if the first left state behind.
"""
schemas = [tenant_schema_at_head, tenant_schema_empty]
first = get_schemas_needing_migration(schemas, current_head_rev)
second = get_schemas_needing_migration(schemas, current_head_rev)
assert first == second
def test_empty_input(current_head_rev: str) -> None:
"""An empty input list returns immediately without touching the DB."""
assert get_schemas_needing_migration([], current_head_rev) == []

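The three fixtures above cover the states the function has to tell apart. As a rough sketch only, and not the actual onyx.db.engine.tenant_utils implementation, the classification these tests expect could look like this: a schema is included when its alembic_version table is missing or records a revision other than the current head.

# Hypothetical sketch of the behavior under test; names and queries are
# illustrative, the real implementation lives in tenant_utils.
from sqlalchemy import text
from sqlalchemy.engine import Engine

def schemas_needing_migration_sketch(
    engine: Engine, schemas: list[str], head_rev: str
) -> list[str]:
    needs_migration: list[str] = []
    with engine.connect() as conn:
        for schema in schemas:
            has_version_table = conn.execute(
                text(
                    "SELECT 1 FROM information_schema.tables "
                    "WHERE table_schema = :schema "
                    "AND table_name = 'alembic_version'"
                ),
                {"schema": schema},
            ).scalar()
            if not has_version_table:
                # no alembic_version table: schema was never migrated
                needs_migration.append(schema)
                continue
            current_rev = conn.execute(
                text(f'SELECT version_num FROM "{schema}".alembic_version')
            ).scalar()
            if current_rev != head_rev:
                # stale revision: schema needs migration
                needs_migration.append(schema)
    return needs_migration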
View File

@@ -4,75 +4,84 @@ import time
import pytest
import requests
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
RETENTION_SECONDS = 10
def _run_ttl_cleanup(retention_days: int) -> None:
"""Directly execute TTL cleanup logic, bypassing Celery task infrastructure."""
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(retention_days, db_session)
for user_id, session_id in old_chat_sessions:
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat retention tests are enterprise only",
)
def test_chat_retention(reset: None, admin_user: DATestUser) -> None: # noqa: ARG001
def test_chat_retention(
reset: None, admin_user: DATestUser, llm_provider: DATestLLMProvider # noqa: ARG001
) -> None: # noqa: ARG001
"""Test that chat sessions are deleted after the retention period expires."""
# Set chat retention period to 10 seconds
retention_days = 10 / 86400 # 10 seconds in days (10 / 24 / 60 / 60)
retention_days = RETENTION_SECONDS / 86400  # 10 seconds expressed as a fraction of a day
settings = DATestSettings(maximum_chat_retention_days=retention_days)
SettingsManager.update_settings(settings, user_performing_action=admin_user)
# Create a chat session
chat_session = ChatSessionManager.create(
persona_id=0,
description="Test chat retention",
user_performing_action=admin_user,
)
# Send a message
ChatSessionManager.send_message(
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="This message should be deleted soon",
user_performing_action=admin_user,
)
assert (
response.error is None
), f"Chat response should not have an error: {response.error}"
# Verify the chat session exists
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
assert len(chat_history) > 0, "Chat session should have messages"
# Wait for TTL task to run (give it ~60 seconds)
print("Waiting for chat retention TTL task to run...")
max_wait_time = 60 # maximum time to wait in seconds
start_time = time.time()
# Wait for the retention period to elapse, then directly run TTL cleanup
time.sleep(RETENTION_SECONDS + 2)
_run_ttl_cleanup(retention_days)
# Verify the chat session was deleted
session_deleted = False
try:
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
session_deleted = len(chat_history) == 0
except requests.exceptions.HTTPError as e:
if e.response.status_code in (404, 400):
session_deleted = True
else:
raise
while not session_deleted and (time.time() - start_time < max_wait_time):
# Check if chat session is deleted
try:
# Attempt to get chat history - this should 404
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
# If we got no messages or an empty response, session might be deleted
if not chat_history:
session_deleted = True
break
except requests.exceptions.HTTPError as e:
# If we get a 404 or other error, the session is gone
if e.response.status_code in (404, 400):
session_deleted = True
break
raise # Re-raise other errors
# Wait a bit before checking again
time.sleep(5)
print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")
# Assert that the chat session was deleted
assert session_deleted, "Chat session was not deleted within the expected time"
assert session_deleted, "Chat session was not deleted after retention period"

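One detail the rewrite depends on: maximum_chat_retention_days is expressed in days, so a ten-second test window has to be passed as a small fractional value, as the original 10 / 86400 computation suggests. A minimal illustration, assuming fractional day values are accepted by the setting:

# Converting a short retention window in seconds to the fractional "days"
# value used for maximum_chat_retention_days (assumption: floats accepted).
RETENTION_SECONDS = 10
retention_days = RETENTION_SECONDS / 86400  # roughly 0.000116 days
assert 0 < retention_days < 1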
View File

@@ -0,0 +1,32 @@
from collections.abc import Generator
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
@pytest.fixture
def preserve_code_interpreter_state(
admin_user: DATestUser,
) -> Generator[None, None, None]:
"""Capture the code interpreter enabled state before a test and restore it
afterwards, so that tests that toggle the setting cannot leak state."""
response = requests.get(
CODE_INTERPRETER_URL,
headers=admin_user.headers,
)
response.raise_for_status()
initial_enabled = response.json()["enabled"]
yield
restore = requests.put(
CODE_INTERPRETER_URL,
json={"enabled": initial_enabled},
headers=admin_user.headers,
)
restore.raise_for_status()

View File

@@ -0,0 +1,97 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
def test_get_code_interpreter_health_as_admin(
admin_user: DATestUser,
) -> None:
"""Health endpoint should return a JSON object with a 'healthy' boolean."""
response = requests.get(
CODE_INTERPRETER_HEALTH_URL,
headers=admin_user.headers,
)
assert response.status_code == 200
data = response.json()
assert "healthy" in data
assert isinstance(data["healthy"], bool)
def test_get_code_interpreter_status_as_admin(
admin_user: DATestUser,
) -> None:
"""GET endpoint should return a JSON object with an 'enabled' boolean."""
response = requests.get(
CODE_INTERPRETER_URL,
headers=admin_user.headers,
)
assert response.status_code == 200
data = response.json()
assert "enabled" in data
assert isinstance(data["enabled"], bool)
def test_update_code_interpreter_disable_and_enable(
admin_user: DATestUser,
preserve_code_interpreter_state: None, # noqa: ARG001
) -> None:
"""PUT endpoint should update the enabled flag and persist across reads."""
# Disable
response = requests.put(
CODE_INTERPRETER_URL,
json={"enabled": False},
headers=admin_user.headers,
)
assert response.status_code == 200
# Verify disabled
response = requests.get(
CODE_INTERPRETER_URL,
headers=admin_user.headers,
)
assert response.status_code == 200
assert response.json()["enabled"] is False
# Re-enable
response = requests.put(
CODE_INTERPRETER_URL,
json={"enabled": True},
headers=admin_user.headers,
)
assert response.status_code == 200
# Verify enabled
response = requests.get(
CODE_INTERPRETER_URL,
headers=admin_user.headers,
)
assert response.status_code == 200
assert response.json()["enabled"] is True
def test_code_interpreter_endpoints_require_admin(
basic_user: DATestUser,
) -> None:
"""All code interpreter endpoints should reject non-admin users."""
health_response = requests.get(
CODE_INTERPRETER_HEALTH_URL,
headers=basic_user.headers,
)
assert health_response.status_code == 403
get_response = requests.get(
CODE_INTERPRETER_URL,
headers=basic_user.headers,
)
assert get_response.status_code == 403
put_response = requests.put(
CODE_INTERPRETER_URL,
json={"enabled": True},
headers=basic_user.headers,
)
assert put_response.status_code == 403

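For orientation, the routes exercised above form a small admin-only surface: a health probe, a status read, and an enable/disable toggle. A sketch of how an admin client might call them; the base URL and auth header are placeholder assumptions, not values from the test suite:

# Minimal client sketch for the admin code-interpreter routes above.
# API_SERVER_URL and the Authorization header are placeholders.
import requests

API_SERVER_URL = "http://localhost:8080"
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
admin_headers = {"Authorization": "Bearer <admin-api-key>"}

health = requests.get(f"{CODE_INTERPRETER_URL}/health", headers=admin_headers)
health.raise_for_status()
print("healthy:", health.json()["healthy"])

status = requests.get(CODE_INTERPRETER_URL, headers=admin_headers)
status.raise_for_status()

# Flip the enabled flag and persist it.
toggle = requests.put(
    CODE_INTERPRETER_URL,
    json={"enabled": not status.json()["enabled"]},
    headers=admin_headers,
)
toggle.raise_for_status()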
View File

@@ -1,195 +0,0 @@
import os
import pytest
import requests
from onyx.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history is enterprise only",
)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None: # noqa: ARG001
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connector
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
)
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
LLMProviderManager.create(user_performing_action=admin_user)
# SEEDING DOCUMENTS
cc_pair_1.documents = []
cc_pair_1.documents.append(
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content="Pablo's favorite color is blue",
api_key=api_key,
)
)
cc_pair_1.documents.append(
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content="Chris's favorite color is red",
api_key=api_key,
)
)
cc_pair_1.documents.append(
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content="Pika's favorite color is green",
api_key=api_key,
)
)
# TESTING RESPONSE FOR QUESTION 1
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "What is Pablo's favorite color?",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# check that the answer is correct
answer_1 = response_json["answer"]
assert "blue" in answer_1.lower()
# FLAKY - check that the llm selected a document
# assert 0 in response_json["llm_selected_doc_indices"]
# check that the final context documents are correct
# (it should contain all documents because there aren't enough to exclude any)
assert 0 in response_json["final_context_doc_indices"]
assert 1 in response_json["final_context_doc_indices"]
assert 2 in response_json["final_context_doc_indices"]
# FLAKY - check that the cited documents are correct
# assert cc_pair_1.documents[0].id in response_json["cited_documents"].values()
# flakiness likely due to non-deterministic rephrasing
# FLAKY - check that the top documents are correct
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
print("response 1/3 passed")
# TESTING RESPONSE FOR QUESTION 2
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "What is Pablo's favorite color?",
"role": MessageType.USER.value,
},
{
"message": answer_1,
"role": MessageType.ASSISTANT.value,
},
{
"message": "What is Chris's favorite color?",
"role": MessageType.USER.value,
},
],
"persona_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# check that the answer is correct
answer_2 = response_json["answer"]
assert "red" in answer_2.lower()
# FLAKY - check that the llm selected a document
# assert 0 in response_json["llm_selected_doc_indices"]
# check that the final context documents are correct
# (it should contain all documents because there aren't enough to exclude any)
assert 0 in response_json["final_context_doc_indices"]
assert 1 in response_json["final_context_doc_indices"]
assert 2 in response_json["final_context_doc_indices"]
# FLAKY - check that the cited documents are correct
# assert cc_pair_1.documents[1].id in response_json["cited_documents"].values()
# flakiness likely due to non-deterministic rephrasing
# FLAKY - check that the top documents are correct
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[1].id
print("response 2/3 passed")
# TESTING RESPONSE FOR QUESTION 3
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "What is Pablo's favorite color?",
"role": MessageType.USER.value,
},
{
"message": answer_1,
"role": MessageType.ASSISTANT.value,
},
{
"message": "What is Chris's favorite color?",
"role": MessageType.USER.value,
},
{
"message": answer_2,
"role": MessageType.ASSISTANT.value,
},
{
"message": "What is Pika's favorite color?",
"role": MessageType.USER.value,
},
],
"persona_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# check that the answer is correct
answer_3 = response_json["answer"]
assert "green" in answer_3.lower()
# FLAKY - check that the llm selected a document
# assert 0 in response_json["llm_selected_doc_indices"]
# check that the final context documents are correct
# (it should contain all documents because there aren't enough to exclude any)
assert 0 in response_json["final_context_doc_indices"]
assert 1 in response_json["final_context_doc_indices"]
assert 2 in response_json["final_context_doc_indices"]
# FLAKY - check that the cited documents are correct
# assert cc_pair_1.documents[2].id in response_json["cited_documents"].values()
# flakiness likely due to non-deterministic rephrasing
# FLAKY - check that the top documents are correct
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
print("response 3/3 passed")

View File

@@ -1,250 +0,0 @@
import json
import os
import pytest
import requests
from onyx.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.conftest import DocumentBuilderType
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
document_builder: DocumentBuilderType,
) -> None:
# create documents using the document builder
# Create NUM_DOCS number of documents with dummy content
content_list = [f"Document {i} content" for i in range(NUM_DOCS)]
docs = document_builder(content_list)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": docs[0].content,
"role": MessageType.USER.value,
}
],
"persona_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# Check that the top document is the correct document
assert response_json["top_documents"][0]["document_id"] == docs[0].id
# assert that the metadata is correct
for doc in docs:
found_doc = next(
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
None,
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_using_reference_docs_with_simple_with_history_api_flow(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
document_builder: DocumentBuilderType,
) -> None:
# SEEDING DOCUMENTS
docs = document_builder(
[
"Chris's favorite color is blue",
"Hagen's favorite color is red",
"Pablo's favorite color is green",
]
)
# SENDING MESSAGE 1
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "What is Pablo's favorite color?",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# get the db_doc_id of the top document to use as a search doc id for second message
first_db_doc_id = response_json["top_documents"][0]["db_doc_id"]
# SENDING MESSAGE 2
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "What is Pablo's favorite color?",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
"search_doc_ids": [first_db_doc_id],
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# make sure there is an answer
assert response_json["answer"]
# This ensures that the document we think we are referencing when we send the search_doc_ids in the second
# message is the document that we expect it to be
assert response_json["top_documents"][0]["document_id"] == docs[2].id
@pytest.mark.skip(reason="We don't support this anymore with the DR flow :(")
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history_strict_json(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
# intentionally not relevant prompt to ensure that the
# structured response format is actually used
"messages": [
{
"message": "What is green?",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
"structured_response_format": {
"type": "json_schema",
"json_schema": {
"name": "presidents",
"schema": {
"type": "object",
"properties": {
"presidents": {
"type": "array",
"items": {"type": "string"},
"description": "List of the first three US presidents",
}
},
"required": ["presidents"],
"additionalProperties": False,
},
"strict": True,
},
},
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# Check that the answer is present
assert "answer" in response_json
assert response_json["answer"] is not None
# helper
def clean_json_string(json_string: str) -> str:
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
# Attempt to parse the answer as JSON
try:
clean_answer = clean_json_string(response_json["answer"])
parsed_answer = json.loads(clean_answer)
# NOTE: do not check content, just the structure
assert isinstance(parsed_answer, dict)
assert "presidents" in parsed_answer
assert isinstance(parsed_answer["presidents"], list)
for president in parsed_answer["presidents"]:
assert isinstance(president, str)
except json.JSONDecodeError:
assert (
False
), f"The answer is not a valid JSON object - '{response_json['answer']}'"
# Check that the answer_citationless is also valid JSON
assert "answer_citationless" in response_json
assert response_json["answer_citationless"] is not None
try:
clean_answer_citationless = clean_json_string(
response_json["answer_citationless"]
)
parsed_answer_citationless = json.loads(clean_answer_citationless)
assert isinstance(parsed_answer_citationless, dict)
except json.JSONDecodeError:
assert False, "The answer_citationless is not a valid JSON object"
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/query/answer-with-citation tests are enterprise only",
)
def test_answer_with_citation_api(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
document_builder: DocumentBuilderType,
) -> None:
# create docs
docs = document_builder(["Chris' favorite color is green"])
# send a message
response = requests.post(
f"{API_SERVER_URL}/query/answer-with-citation",
json={
"messages": [
{
"message": "What is Chris' favorite color? Make sure to cite the document.",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
},
headers=admin_user.headers,
cookies=admin_user.cookies,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["answer"]
has_correct_citation = False
for citation in response_json["citations"]:
if citation["document_id"] == docs[0].id:
has_correct_citation = True
break
assert has_correct_citation

View File

@@ -2,7 +2,6 @@ import os
import uuid
from datetime import datetime
from datetime import timezone
from unittest.mock import patch
import httpx
import pytest
@@ -12,6 +11,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_EMAILS
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_GROUP_IDS
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import InputType
from onyx.db.document import get_documents_by_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -25,128 +25,16 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync is enterprise only",
)
def test_mock_connector_initial_permission_sync(
def _setup_mock_connector(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the MockConnector fetches and sets permissions during initial indexing when AccessType.SYNC is used"""
# Set up mock server behavior
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
# Create CC Pair with SYNC access type to enable permissions during indexing
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-permissions-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC, # This enables permissions during indexing
user_performing_action=admin_user,
)
# Wait for index attempt to start
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Wait for index attempt to finish
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.SUCCESS
# Verify document was indexed
with get_session_with_current_tenant() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
# Verify no errors occurred
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 0
# Verify permissions were set during indexing by checking the document in the database
with get_session_with_current_tenant() as db_session:
db_docs = get_documents_by_ids(
db_session=db_session,
document_ids=[test_doc.id],
)
assert len(db_docs) == 1
db_doc = db_docs[0]
assert db_doc.external_user_emails is not None
assert db_doc.external_user_group_ids is not None
# Check the specific permissions that MockConnector sets
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
# Verify the document is not public (as set by MockConnector)
assert db_doc.is_public is False
# Verify that the cc_pair was marked as permissions synced
updated_cc_pair_info = CCPairManager.get_single(
cc_pair.id, user_performing_action=admin_user
)
assert updated_cc_pair_info is not None
assert updated_cc_pair_info.last_full_permission_sync is not None
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_integration(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked during real sync workflows."""
) -> tuple[DATestCCPair, Document]:
"""Common setup: create a test doc, configure mock server, create cc_pair, wait for indexing."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
@@ -165,7 +53,7 @@ def test_permission_sync_attempt_tracking_integration(
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-attempt-tracking-{uuid.uuid4()}",
name=f"mock-connector-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
@@ -187,6 +75,95 @@ def test_permission_sync_attempt_tracking_integration(
user_performing_action=admin_user,
)
finished = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished.status == IndexingStatus.SUCCESS
return cc_pair, test_doc
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync is enterprise only",
)
def test_mock_connector_initial_permission_sync(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the MockConnector fetches and sets permissions during initial indexing
when AccessType.SYNC is used."""
cc_pair, test_doc = _setup_mock_connector(mock_server_client, admin_user)
with get_session_with_current_tenant() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 0
with get_session_with_current_tenant() as db_session:
db_docs = get_documents_by_ids(
db_session=db_session,
document_ids=[test_doc.id],
)
assert len(db_docs) == 1
db_doc = db_docs[0]
assert db_doc.external_user_emails is not None
assert db_doc.external_user_group_ids is not None
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
assert db_doc.is_public is False
# After initial indexing, the beat task detects last_time_perm_sync is None
# and triggers a doc permission sync. Explicitly trigger it to avoid
# waiting for the 30s beat interval.
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
updated_cc_pair_info = CCPairManager.get_single(
cc_pair.id, user_performing_action=admin_user
)
assert updated_cc_pair_info is not None
assert updated_cc_pair_info.last_full_permission_sync is not None
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_integration(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked during real sync workflows."""
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
@@ -198,6 +175,8 @@ def test_permission_sync_attempt_tracking_integration(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
with get_session_with_current_tenant() as db_session:
@@ -219,88 +198,6 @@ def test_permission_sync_attempt_tracking_integration(
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_with_mocked_failure(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked when sync fails."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-attempt-failure-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Mock the permission sync to force a failure and verify attempt tracking
with patch(
"ee.onyx.background.celery.tasks.doc_permission_syncing.tasks.validate_ccpair_for_user"
) as mock_validate:
mock_validate.side_effect = Exception("Validation failed for testing")
try:
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=0,
user_performing_action=admin_user,
)
except Exception:
pass
with get_session_with_current_tenant() as db_session:
attempt = db_session.execute(
select(DocPermissionSyncAttempt).where(
DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair.id
)
).scalar_one()
assert attempt.status == PermissionSyncStatus.FAILED
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
@@ -311,45 +208,8 @@ def test_permission_sync_attempt_status_success(
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are marked as SUCCESS when sync completes without errors."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-success-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
before = datetime.now(timezone.utc)
CCPairManager.sync(
@@ -362,6 +222,8 @@ def test_permission_sync_attempt_status_success(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
with get_session_with_current_tenant() as db_session:

Some files were not shown because too many files have changed in this diff.