mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 11:45:47 +00:00
Compare commits
6 Commits
ci_python_
...
unify-user
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
165cb46a2a | ||
|
|
ee4d113474 | ||
|
|
129e3698ef | ||
|
|
1ff4fe2d63 | ||
|
|
113db07e6d | ||
|
|
d6d2e4f8fd |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
161
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
161
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
@@ -1,161 +0,0 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
|
||||
reason="command-failed"
|
||||
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
|
||||
reason="merge-conflict"
|
||||
fi
|
||||
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "details<<EOF"
|
||||
tail -n 40 "$output_file"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
@@ -116,6 +116,7 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
@@ -21,14 +21,15 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import NamedTuple
|
||||
from typing import List, NamedTuple
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
@@ -104,6 +105,56 @@ def get_head_revision() -> str | None:
|
||||
return script.get_current_head()
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
|
||||
tenant_schemas: List[str], head_rev: str
|
||||
) -> List[str]:
|
||||
"""Return only schemas whose current alembic version is not at head."""
|
||||
if not tenant_schemas:
|
||||
return []
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Find which schemas actually have an alembic_version table
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"SELECT table_schema FROM information_schema.tables "
|
||||
"WHERE table_name = 'alembic_version' "
|
||||
"AND table_schema = ANY(:schemas)"
|
||||
),
|
||||
{"schemas": tenant_schemas},
|
||||
)
|
||||
schemas_with_table = set(row[0] for row in rows)
|
||||
|
||||
# Schemas without the table definitely need migration
|
||||
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
|
||||
|
||||
if not schemas_with_table:
|
||||
return needs_migration
|
||||
|
||||
# Validate schema names before interpolating into SQL
|
||||
for schema in schemas_with_table:
|
||||
if not is_valid_schema_name(schema):
|
||||
raise ValueError(f"Invalid schema name: {schema}")
|
||||
|
||||
# Single query to get every schema's current revision at once.
|
||||
# Use integer tags instead of interpolating schema names into
|
||||
# string literals to avoid quoting issues.
|
||||
schema_list = list(schemas_with_table)
|
||||
union_parts = [
|
||||
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
|
||||
for i, schema in enumerate(schema_list)
|
||||
]
|
||||
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
|
||||
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
|
||||
|
||||
needs_migration.extend(
|
||||
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
|
||||
)
|
||||
|
||||
return needs_migration
|
||||
|
||||
|
||||
def run_migrations_parallel(
|
||||
schemas: list[str],
|
||||
max_workers: int,
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""code interpreter seed
|
||||
|
||||
Revision ID: 07b98176f1de
|
||||
Revises: 7cb492013621
|
||||
Create Date: 2026-02-23 15:55:07.606784
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "07b98176f1de"
|
||||
down_revision = "7cb492013621"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Seed the single instance of code_interpreter_server
|
||||
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
|
||||
op.execute(
|
||||
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(sa.text("DELETE FROM code_interpreter_server"))
|
||||
@@ -1,48 +0,0 @@
|
||||
"""add enterprise and name fields to scim_user_mapping
|
||||
|
||||
Revision ID: 7616121f6e97
|
||||
Revises: 07b98176f1de
|
||||
Create Date: 2026-02-23 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7616121f6e97"
|
||||
down_revision = "07b98176f1de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("department", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("manager", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("given_name", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("family_name", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_emails_json", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_emails_json")
|
||||
op.drop_column("scim_user_mapping", "family_name")
|
||||
op.drop_column("scim_user_mapping", "given_name")
|
||||
op.drop_column("scim_user_mapping", "manager")
|
||||
op.drop_column("scim_user_mapping", "department")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7cb492013621
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7cb492013621"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_persona_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_file", "needs_persona_sync")
|
||||
@@ -127,14 +127,9 @@ class ScimDAL(DAL):
|
||||
self,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
return mapping
|
||||
@@ -253,11 +248,11 @@ class ScimDAL(DAL):
|
||||
scim_filter: ScimFilter | None,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
|
||||
) -> tuple[list[tuple[User, str | None]], int]:
|
||||
"""Query users with optional SCIM filter and pagination.
|
||||
|
||||
Returns:
|
||||
A tuple of (list of (user, mapping) pairs, total_count).
|
||||
A tuple of (list of (user, external_id) pairs, total_count).
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
@@ -297,104 +292,33 @@ class ScimDAL(DAL):
|
||||
users = list(
|
||||
self._session.scalars(
|
||||
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
).all()
|
||||
)
|
||||
|
||||
# Batch-fetch SCIM mappings to avoid N+1 queries
|
||||
mapping_map = self._get_user_mappings_batch([u.id for u in users])
|
||||
return [(u, mapping_map.get(u.id)) for u in users], total
|
||||
# Batch-fetch external IDs to avoid N+1 queries
|
||||
ext_id_map = self._get_user_external_ids([u.id for u in users])
|
||||
return [(u, ext_id_map.get(u.id)) for u in users], total
|
||||
|
||||
def sync_user_external_id(
|
||||
self,
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
) -> None:
|
||||
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, ScimUserMapping]:
|
||||
"""Batch-fetch SCIM user mappings keyed by user ID."""
|
||||
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
|
||||
"""Batch-fetch external IDs for a list of user IDs."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
return {m.user_id: m for m in mappings}
|
||||
|
||||
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
|
||||
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
|
||||
|
||||
Excludes groups marked for deletion.
|
||||
"""
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
|
||||
).all()
|
||||
|
||||
group_ids = [r.user_group_id for r in rels]
|
||||
if not group_ids:
|
||||
return []
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
return [(g.id, g.name) for g in groups]
|
||||
|
||||
def get_users_groups_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Batch-fetch group memberships for multiple users.
|
||||
|
||||
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
|
||||
Avoids N+1 queries when building user list responses.
|
||||
"""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
group_ids = list({r.user_group_id for r in rels})
|
||||
if not group_ids:
|
||||
return {}
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
groups_by_id = {g.id: g.name for g in groups}
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {}
|
||||
for r in rels:
|
||||
if r.user_id and r.user_group_id in groups_by_id:
|
||||
result.setdefault(r.user_id, []).append(
|
||||
(r.user_group_id, groups_by_id[r.user_group_id])
|
||||
)
|
||||
return result
|
||||
return {m.user_id: m.external_id for m in mappings}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group mapping operations
|
||||
@@ -559,13 +483,9 @@ class ScimDAL(DAL):
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
users = self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
users_by_id = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
@@ -584,13 +504,9 @@ class ScimDAL(DAL):
|
||||
"""
|
||||
if not uuids:
|
||||
return []
|
||||
existing_users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
existing_users = self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
existing_ids = {u.id for u in existing_users}
|
||||
return [uid for uid in uuids if uid not in existing_ids]
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 30
|
||||
num_hits: int = 50
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
@@ -26,10 +26,12 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -39,8 +41,6 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -53,6 +53,7 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
@@ -62,18 +63,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
"""Resolve the SCIM provider for the current request.
|
||||
|
||||
Currently returns OktaProvider for all requests. When multi-provider
|
||||
support is added (ENG-3652), this will resolve based on token metadata
|
||||
or tenant configuration — no endpoint changes required.
|
||||
"""
|
||||
return get_default_provider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -111,6 +100,28 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
|
||||
"""Convert an Onyx User to a SCIM User resource representation."""
|
||||
name = None
|
||||
if user.personal_name:
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
name = ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=user.email,
|
||||
name=name,
|
||||
emails=[ScimEmail(value=user.email, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -144,10 +155,9 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or name.formatted
|
||||
return name.formatted or " ".join(
|
||||
part for part in [name.givenName, name.familyName] if part
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -161,7 +171,6 @@ def list_users(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
@@ -174,19 +183,12 @@ def list_users(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
|
||||
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
@@ -201,7 +203,6 @@ def list_users(
|
||||
def get_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
@@ -214,26 +215,20 @@ def get_user(
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
email = user_resource.userName.strip().lower()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
@@ -269,14 +264,11 @@ def create_user(
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
return _user_to_scim(user, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -284,7 +276,6 @@ def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -302,27 +293,19 @@ def replace_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip(),
|
||||
email=user_resource.userName.strip().lower(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=personal_name,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
return _user_to_scim(user, new_external_id)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
@@ -330,7 +313,6 @@ def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
@@ -348,19 +330,11 @@ def patch_user(
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
)
|
||||
current = _user_to_scim(user, external_id)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
@@ -371,40 +345,22 @@ def patch_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Track the scim_username — if userName was patched, update it
|
||||
new_scim_username = patched.userName.strip() if patched.userName else None
|
||||
|
||||
# If displayName was explicitly patched (different from the original), use
|
||||
# it as personal_name directly. Otherwise, derive from name components.
|
||||
personal_name: str | None
|
||||
if patched.displayName and patched.displayName != current.displayName:
|
||||
personal_name = patched.displayName
|
||||
else:
|
||||
personal_name = _scim_name_to_str(patched.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip()
|
||||
if patched.userName.strip().lower() != user.email.lower()
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=personal_name,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
)
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
@@ -442,6 +398,24 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
@@ -500,7 +474,6 @@ def list_groups(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
@@ -518,7 +491,7 @@ def list_groups(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
@@ -534,7 +507,6 @@ def list_groups(
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
@@ -549,16 +521,13 @@ def get_group(
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
@@ -596,7 +565,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
return _group_to_scim(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -604,7 +573,6 @@ def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -627,7 +595,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -635,7 +603,6 @@ def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
@@ -654,11 +621,11 @@ def patch_group(
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = provider.build_group_resource(group, current_members, external_id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
patch_request.Operations, current
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
@@ -685,7 +652,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
|
||||
@@ -63,13 +63,6 @@ class ScimMeta(BaseModel):
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserGroupRef(BaseModel):
|
||||
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
|
||||
|
||||
value: str
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -83,10 +76,8 @@ class ScimUserResource(BaseModel):
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
displayName: str | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
@@ -130,40 +121,12 @@ class ScimPatchOperationType(str, Enum):
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchResourceValue(BaseModel):
|
||||
"""Partial resource dict for path-less PATCH replace operations.
|
||||
|
||||
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
|
||||
of resource attributes to set. IdPs may include read-only fields
|
||||
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
|
||||
stripped by the provider's ``ignored_patch_paths`` before processing.
|
||||
|
||||
``extra="allow"`` lets unknown attributes pass through so the patch
|
||||
handler can decide what to do with them (ignore or reject).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
active: bool | None = None
|
||||
userName: str | None = None
|
||||
displayName: str | None = None
|
||||
externalId: str | None = None
|
||||
name: ScimName | None = None
|
||||
members: list[ScimGroupMember] | None = None
|
||||
id: str | None = None
|
||||
schemas: list[str] | None = None
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
|
||||
@@ -16,12 +16,9 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
@@ -44,15 +41,9 @@ _MEMBER_FILTER_RE = re.compile(
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
@@ -64,9 +55,9 @@ def apply_user_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
_apply_user_replace(op, data, name_data)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
_apply_user_replace(op, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
@@ -80,34 +71,30 @@ def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
# No path — value is a dict of top-level attributes to set
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
_set_user_field(key.lower(), val, data, name_data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
_set_user_field(path, op.value, data, name_data)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "active":
|
||||
if path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
@@ -120,7 +107,7 @@ def _set_user_field(
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
@@ -129,15 +116,9 @@ def _set_user_field(
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current group resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
@@ -152,9 +133,7 @@ def apply_group_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(
|
||||
op, data, current_members, added_ids, removed_ids, ignored_paths
|
||||
)
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
@@ -175,48 +154,38 @@ def _apply_group_replace(
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
dumped = op.value.model_dump(exclude_unset=True)
|
||||
for key, val in dumped.items():
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data, ignored_paths)
|
||||
_set_group_field(key.lower(), val, data)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(
|
||||
_members_to_dicts(op.value), current_members, added_ids, removed_ids
|
||||
)
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data, ignored_paths)
|
||||
|
||||
|
||||
def _members_to_dicts(
|
||||
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
|
||||
) -> list[dict]:
|
||||
"""Convert a member list value to a list of dicts for internal processing."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
return [m.model_dump(exclude_none=True) for m in value]
|
||||
_set_group_field(path, op.value, data)
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: list[dict],
|
||||
value: str | list | dict | bool | None,
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
@@ -228,14 +197,11 @@ def _replace_members(
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
value: str | bool | dict | list | None,
|
||||
data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "displayname":
|
||||
if path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
@@ -257,10 +223,8 @@ def _apply_group_add(
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in member_dicts:
|
||||
for member_data in op.value:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Base SCIM provider abstraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
Subclass this to handle IdP-specific quirks. The base class provides
|
||||
RFC 7643-compliant response builders that populate all standard fields.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Short identifier for this provider (e.g. ``"okta"``)."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
|
||||
|
||||
IdPs may include read-only or meta fields alongside actual changes
|
||||
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
|
||||
here are silently dropped instead of raising an error.
|
||||
"""
|
||||
...
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
Args:
|
||||
user: The Onyx user model.
|
||||
external_id: The IdP's external identifier for this user.
|
||||
groups: List of ``(group_id, group_name)`` tuples for the
|
||||
``groups`` read-only attribute. Pass ``None`` or ``[]``
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
"""
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Build a SCIM Group response from an Onyx UserGroup."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
Currently returns ``OktaProvider`` since Okta is the primary supported
|
||||
IdP. When provider detection is added (via token metadata or tenant
|
||||
config), this can be replaced with dynamic resolution.
|
||||
"""
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
return OktaProvider()
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Okta SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
class OktaProvider(ScimProvider):
|
||||
"""Okta SCIM provider.
|
||||
|
||||
Okta behavioral notes:
|
||||
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
|
||||
- Sends path-less PATCH with value dicts containing extra fields
|
||||
(``id``, ``schemas``)
|
||||
- Expects ``displayName`` and ``groups`` in user responses
|
||||
- Only uses ``eq`` operator for ``userName`` filter
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "okta"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
@@ -277,32 +277,13 @@ def verify_email_domain(email: str) -> None:
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
)
|
||||
domain = email.split("@")[-1].lower()
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
|
||||
@@ -276,6 +277,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
personas: list[int] | None = vespa_chunk.get(PERSONAS)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
@@ -325,6 +327,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
personas=personas,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
|
||||
@@ -5,13 +5,12 @@ from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
@@ -26,14 +25,12 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -79,58 +76,10 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
|
||||
|
||||
def enqueue_user_file_project_sync_task(
|
||||
*,
|
||||
celery_app: Celery,
|
||||
redis_client: Redis,
|
||||
user_file_id: str | UUID,
|
||||
tenant_id: str,
|
||||
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
|
||||
) -> bool:
|
||||
"""Enqueue a project-sync task if no matching queued task already exists."""
|
||||
queued_key = _user_file_project_sync_queued_key(user_file_id)
|
||||
|
||||
# NX+EX gives us atomic dedupe and a self-healing TTL.
|
||||
queued_guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
nx=True,
|
||||
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
if not queued_guard_set:
|
||||
return False
|
||||
|
||||
try:
|
||||
celery_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=priority,
|
||||
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
# Roll back the queued guard if task publish fails.
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
|
||||
def _visit_chunks(
|
||||
*,
|
||||
@@ -684,8 +633,8 @@ def process_single_user_file_delete(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files needing project sync and enqueue per-file tasks."""
|
||||
task_logger.info("Starting")
|
||||
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
|
||||
task_logger.info("check_for_user_file_project_sync - Starting")
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -697,22 +646,16 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(self.app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"Queue depth {queue_depth} exceeds "
|
||||
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -722,23 +665,19 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
if not enqueue_user_file_project_sync_task(
|
||||
celery_app=self.app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
):
|
||||
skipped_guard += 1
|
||||
continue
|
||||
)
|
||||
enqueued += 1
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"Enqueued {enqueued} "
|
||||
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
|
||||
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -757,8 +696,6 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
@@ -772,7 +709,11 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
@@ -800,13 +741,17 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -814,6 +759,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
|
||||
@@ -58,8 +58,6 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
|
||||
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
@@ -158,7 +156,36 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
logger.warning(
|
||||
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
|
||||
)
|
||||
cleaned_batch.append(sanitize_document_for_postgres(doc))
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
@@ -575,13 +602,10 @@ def connector_document_extraction(
|
||||
|
||||
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
|
||||
if hierarchy_node_batch:
|
||||
hierarchy_node_batch_cleaned = (
|
||||
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=hierarchy_node_batch_cleaned,
|
||||
nodes=hierarchy_node_batch,
|
||||
source=db_connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
@@ -600,7 +624,7 @@ def connector_document_extraction(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
|
||||
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
@@ -461,7 +461,7 @@ def _build_tool_call_response_history_message(
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
project_image_files: list[ChatLoadedFile],
|
||||
context_image_files: list[ChatLoadedFile],
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
@@ -541,11 +541,11 @@ def convert_chat_history(
|
||||
)
|
||||
|
||||
# Add the user message with image files attached
|
||||
# If this is the last USER message, also include project_image_files
|
||||
# Note: project image file tokens are NOT counted in the token count
|
||||
# If this is the last USER message, also include context_image_files
|
||||
# Note: context image file tokens are NOT counted in the token count
|
||||
if idx == last_user_message_idx:
|
||||
if project_image_files:
|
||||
image_files.extend(project_image_files)
|
||||
if context_image_files:
|
||||
image_files.extend(context_image_files)
|
||||
|
||||
if additional_context:
|
||||
simple_messages.append(
|
||||
|
||||
@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.chat.prompt_utils import build_reminder_message
|
||||
from onyx.chat.prompt_utils import build_system_prompt
|
||||
@@ -30,7 +30,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
@@ -203,17 +202,17 @@ def _try_fallback_tool_extraction(
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
def _build_context_file_citation_mapping(
|
||||
file_metadata: list[ContextFileMetadata],
|
||||
starting_citation_num: int = 1,
|
||||
) -> CitationMapping:
|
||||
"""Build citation mapping for project files.
|
||||
"""Build citation mapping for context files.
|
||||
|
||||
Converts project file metadata into SearchDoc objects that can be cited.
|
||||
Converts context file metadata into SearchDoc objects that can be cited.
|
||||
Citation numbers start from the provided starting number.
|
||||
|
||||
Args:
|
||||
project_file_metadata: List of project file metadata
|
||||
file_metadata: List of context file metadata
|
||||
starting_citation_num: Starting citation number (default: 1)
|
||||
|
||||
Returns:
|
||||
@@ -221,8 +220,7 @@ def _build_project_file_citation_mapping(
|
||||
"""
|
||||
citation_mapping: CitationMapping = {}
|
||||
|
||||
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
|
||||
# Create a SearchDoc for each project file
|
||||
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
|
||||
search_doc = SearchDoc(
|
||||
document_id=file_meta.file_id,
|
||||
chunk_ind=0,
|
||||
@@ -242,22 +240,21 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
project_files: ExtractedContextFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
"""Build messages for context-injected / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
1. The full-text files message (if file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
FileReaderTool (e.g. oversized files that don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
if project_files.file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
)
|
||||
@@ -275,7 +272,7 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
project_files: ExtractedContextFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
@@ -445,13 +442,13 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
if project_files and project_files.image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
token_count=last_user_message.token_count,
|
||||
message_type=last_user_message.message_type,
|
||||
image_files=existing_images + project_files.project_image_files,
|
||||
image_files=existing_images + project_files.image_files,
|
||||
)
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
@@ -548,10 +545,10 @@ def _create_file_tool_metadata_message(
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
project_files: ExtractedContextFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
) -> ChatMessageSimple:
|
||||
"""Convert project files to a ChatMessageSimple message.
|
||||
"""Convert context files to a ChatMessageSimple message.
|
||||
|
||||
Format follows the README specification for document representation.
|
||||
"""
|
||||
@@ -559,7 +556,7 @@ def _create_project_files_message(
|
||||
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
|
||||
for idx, file_text in enumerate(project_files.file_texts, start=1):
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
@@ -584,7 +581,7 @@ def run_llm_loop(
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
project_files: ExtractedContextFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
@@ -627,9 +624,9 @@ def run_llm_loop(
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
if project_files.project_file_metadata:
|
||||
project_citation_mapping = _build_project_file_citation_mapping(
|
||||
project_files.project_file_metadata
|
||||
if project_files.file_metadata:
|
||||
project_citation_mapping = _build_context_file_citation_mapping(
|
||||
project_files.file_metadata
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
@@ -647,7 +644,7 @@ def run_llm_loop(
|
||||
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
|
||||
# One future workaround is to include the images as separate user messages with citation information and process those.
|
||||
always_cite_documents: bool = bool(
|
||||
project_files.project_as_filter or project_files.project_file_texts
|
||||
project_files.use_as_search_filter or project_files.file_texts
|
||||
)
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
@@ -657,12 +654,7 @@ def run_llm_loop(
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
# Fetch this in a short-lived session so the long-running stream loop does
|
||||
# not pin a connection just to keep read state alive.
|
||||
with get_session_with_current_tenant() as prompt_db_session:
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(
|
||||
prompt_db_session
|
||||
)
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
system_prompt = None
|
||||
custom_agent_prompt_msg = None
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
@@ -31,13 +30,6 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
@@ -132,8 +124,8 @@ class ChatMessageSimple(BaseModel):
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
"""Metadata for a project file to enable citation support."""
|
||||
class ContextFileMetadata(BaseModel):
|
||||
"""Metadata for a context-injected file to enable citation support."""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
@@ -167,17 +159,17 @@ class ChatHistoryResult(BaseModel):
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
project_as_filter: bool
|
||||
class ExtractedContextFiles(BaseModel):
|
||||
"""Result of attempting to load user files (from a project or persona) into context."""
|
||||
|
||||
file_texts: list[str]
|
||||
image_files: list[ChatLoadedFile]
|
||||
use_as_search_filter: bool
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
# (populated when files don't fit in context and vector DB is disabled).
|
||||
file_metadata: list[ContextFileMetadata]
|
||||
uncapped_token_count: int | None
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -33,11 +34,10 @@ from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
@@ -62,11 +62,12 @@ from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -192,9 +193,67 @@ def _convert_loaded_files_to_chat_files(
|
||||
return chat_files
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
def _resolve_context_user_files(
|
||||
persona: Persona,
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""Apply the precedence rule to decide which user files to load.
|
||||
|
||||
A custom persona fully supersedes the project. When a chat uses a
|
||||
custom persona, the project is purely organisational — its files are
|
||||
never loaded and never made searchable.
|
||||
|
||||
Custom persona → persona's own user_files (may be empty).
|
||||
Default persona inside a project → project files.
|
||||
Otherwise → empty list.
|
||||
"""
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
return list(persona.user_files) if persona.user_files else []
|
||||
if project_id:
|
||||
return get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=False,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8").replace("\x00", "")
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_context_files(
|
||||
user_files: list[UserFile],
|
||||
llm_max_context_window: int,
|
||||
reserved_token_count: int,
|
||||
db_session: Session,
|
||||
@@ -203,8 +262,12 @@ def _extract_project_file_texts_and_images(
|
||||
# 60% of the LLM's max context window. The other benefit is that for projects with
|
||||
# more files, this makes it so that we don't throw away the history too quickly every time.
|
||||
max_llm_context_percentage: float = 0.6,
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Extract text content from project files if they fit within the context window.
|
||||
) -> ExtractedContextFiles:
|
||||
"""Load user files into context if they fit; otherwise flag for search.
|
||||
|
||||
The caller is responsible for deciding *which* user files to pass in
|
||||
(project files, persona files, etc.). This function only cares about
|
||||
the all-or-nothing fit check and the actual content loading.
|
||||
|
||||
Args:
|
||||
project_id: The project ID to load files from
|
||||
@@ -213,160 +276,95 @@ def _extract_project_file_texts_and_images(
|
||||
reserved_token_count: Number of tokens to reserve for other content
|
||||
db_session: Database session
|
||||
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
|
||||
|
||||
Returns:
|
||||
ExtractedProjectFiles containing:
|
||||
- List of text content strings from project files (text files only)
|
||||
- List of image files from project (ChatLoadedFile objects)
|
||||
- Project id if the the project should be provided as a filter in search or None if not.
|
||||
ExtractedContextFiles containing:
|
||||
- List of text content strings from context files (text files only)
|
||||
- List of image files from context (ChatLoadedFile objects)
|
||||
- Total token count of all extracted files
|
||||
- File metadata for context files
|
||||
- Uncapped token count of all extracted files
|
||||
- File metadata for files that don't fit in context and vector DB is disabled
|
||||
"""
|
||||
# TODO I believe this is not handling all file types correctly.
|
||||
project_as_filter = False
|
||||
if not project_id:
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
# TODO(yuhong): I believe this is not handling all file types correctly.
|
||||
|
||||
if not user_files:
|
||||
return _empty_extracted_context_files()
|
||||
|
||||
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
|
||||
max_actual_tokens = (
|
||||
llm_max_context_window - reserved_token_count
|
||||
) * max_llm_context_percentage
|
||||
|
||||
# Calculate total token count for all user files in the project
|
||||
project_tokens = get_project_token_count(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
if aggregate_tokens >= max_actual_tokens:
|
||||
tool_metadata = []
|
||||
use_as_search_filter = not DISABLE_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB:
|
||||
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
|
||||
return ExtractedContextFiles(
|
||||
file_texts=[],
|
||||
image_files=[],
|
||||
use_as_search_filter=use_as_search_filter,
|
||||
total_token_count=0,
|
||||
file_metadata=[],
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
file_metadata_for_tool=tool_metadata,
|
||||
)
|
||||
|
||||
# Files fit — load them into context
|
||||
user_file_map = {str(uf.id): uf for uf in user_files}
|
||||
in_memory_files = load_in_memory_chat_files(
|
||||
user_file_ids=[uf.id for uf in user_files],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
project_file_texts: list[str] = []
|
||||
project_image_files: list[ChatLoadedFile] = []
|
||||
project_file_metadata: list[ProjectFileMetadata] = []
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
total_token_count = 0
|
||||
if project_tokens < max_actual_tokens:
|
||||
# Load project files into memory using cached plaintext when available
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if project_user_files:
|
||||
# Create a mapping from file_id to UserFile for token count lookup
|
||||
user_file_map = {str(file.id): file for file in project_user_files}
|
||||
|
||||
project_file_ids = [file.id for file in project_user_files]
|
||||
in_memory_project_files = load_in_memory_chat_files(
|
||||
user_file_ids=project_file_ids,
|
||||
db_session=db_session,
|
||||
for f in in_memory_files:
|
||||
uf = user_file_map.get(str(f.file_id))
|
||||
if f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
file_texts.append(text_content)
|
||||
file_metadata.append(
|
||||
ContextFileMetadata(
|
||||
file_id=str(f.file_id),
|
||||
filename=f.filename or f"file_{f.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
if uf and uf.token_count:
|
||||
total_token_count += uf.token_count
|
||||
elif f.file_type == ChatFileType.IMAGE:
|
||||
token_count = uf.token_count if uf and uf.token_count else 0
|
||||
total_token_count += token_count
|
||||
image_files.append(
|
||||
ChatLoadedFile(
|
||||
file_id=f.file_id,
|
||||
content=f.content,
|
||||
file_type=f.file_type,
|
||||
filename=f.filename,
|
||||
content_text=None,
|
||||
token_count=token_count,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract text content from loaded files
|
||||
for file in in_memory_project_files:
|
||||
if file.file_type.is_text_file():
|
||||
try:
|
||||
text_content = file.content.decode("utf-8", errors="ignore")
|
||||
# Strip null bytes
|
||||
text_content = text_content.replace("\x00", "")
|
||||
if text_content:
|
||||
project_file_texts.append(text_content)
|
||||
# Add metadata for citation support
|
||||
project_file_metadata.append(
|
||||
ProjectFileMetadata(
|
||||
file_id=str(file.file_id),
|
||||
filename=file.filename or f"file_{file.file_id}",
|
||||
file_content=text_content,
|
||||
)
|
||||
)
|
||||
# Add token count for text file
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
if user_file and user_file.token_count:
|
||||
total_token_count += user_file.token_count
|
||||
except Exception:
|
||||
# Skip files that can't be decoded
|
||||
pass
|
||||
elif file.file_type == ChatFileType.IMAGE:
|
||||
# Convert InMemoryChatFile to ChatLoadedFile
|
||||
user_file = user_file_map.get(str(file.file_id))
|
||||
token_count = (
|
||||
user_file.token_count
|
||||
if user_file and user_file.token_count
|
||||
else 0
|
||||
)
|
||||
total_token_count += token_count
|
||||
chat_loaded_file = ChatLoadedFile(
|
||||
file_id=file.file_id,
|
||||
content=file.content,
|
||||
file_type=file.file_type,
|
||||
filename=file.filename,
|
||||
content_text=None, # Images don't have text content
|
||||
token_count=token_count,
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=project_as_filter,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=aggregate_tokens,
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
@@ -381,58 +379,6 @@ def _build_file_tool_metadata_for_user_files(
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
"""Determine search tool availability based on project context.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
)
|
||||
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
@@ -661,26 +607,49 @@ def handle_stream_message_objects(
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
# Determine which user files to use. A custom persona fully
|
||||
# supersedes the project — project files are never loaded or
|
||||
# searchable when a custom persona is in play. Only the default
|
||||
# persona inside a project uses the project's files.
|
||||
context_user_files = _resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
extracted_context_files = _extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
# Figure out which search filter to pass. When all attached files
|
||||
# fit in context we omit the filter (content is already in prompt).
|
||||
# When they overflow, we pass the appropriate filter so the vector
|
||||
# DB scopes results to these files.
|
||||
#
|
||||
# A custom persona fully supersedes the project — project files are
|
||||
# never loaded, never searchable, and the search tool config is
|
||||
# entirely controlled by the persona. The project_id filter is
|
||||
# only set for the default persona.
|
||||
is_custom_persona = persona.id != DEFAULT_PERSONA_ID
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona.id
|
||||
else:
|
||||
search_project_id = chat_session.project_id
|
||||
|
||||
# Also grant access to persona-attached user files for FileReaderTool
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
@@ -689,30 +658,31 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
# Determine search forcing based on context file state.
|
||||
# Custom personas always get AUTO — their tool config is never
|
||||
# overridden. Only the default persona in a project gets special
|
||||
# treatment (ENABLED when files overflow, DISABLED when loaded or
|
||||
# absent).
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
forced_tool_id = None
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
|
||||
if not is_custom_persona and chat_session.project_id:
|
||||
has_context_files = bool(extracted_context_files.uncapped_token_count)
|
||||
files_loaded_in_context = bool(extracted_context_files.file_texts)
|
||||
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
if (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -722,11 +692,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
project_id=search_project_id,
|
||||
persona_id=search_persona_id,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -744,7 +711,7 @@ def handle_stream_message_objects(
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_usage,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -783,7 +750,7 @@ def handle_stream_message_objects(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
@@ -856,11 +823,6 @@ def handle_stream_message_objects(
|
||||
reserved_tokens=reserved_token_count,
|
||||
)
|
||||
|
||||
# Release any read transaction before entering the long-running LLM stream.
|
||||
# Without this, the request-scoped session can keep a connection checked out
|
||||
# for the full stream duration.
|
||||
db_session.commit()
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
# Set this right before launching the LLM loop so run_in_background copies the right context.
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
@@ -906,7 +868,7 @@ def handle_stream_message_objects(
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
project_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
|
||||
@@ -167,14 +167,6 @@ CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
# How long a queued user-file-project-sync task remains valid.
|
||||
# Should be short enough to discard stale queue entries under load while still
|
||||
# allowing workers enough time to pick up new tasks.
|
||||
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before user-file-project-sync producers stop enqueuing.
|
||||
# This applies backpressure when workers are falling behind.
|
||||
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
@@ -467,7 +459,6 @@ class OnyxRedisLocks:
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
|
||||
|
||||
The office365 library's GraphClient requires an ``AzureEnvironment`` string
|
||||
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
|
||||
cloud. Our connectors instead expose free-text ``authority_host`` and
|
||||
``graph_api_host`` fields so the frontend doesn't need to know about SDK
|
||||
internals.
|
||||
|
||||
This module bridges the gap: given the two host URLs the user configured, it
|
||||
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
|
||||
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
|
||||
"""
|
||||
|
||||
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
|
||||
|
||||
class MicrosoftGraphEnvironment(BaseModel):
|
||||
"""One row of the inverse mapping."""
|
||||
|
||||
environment: str
|
||||
graph_host: str
|
||||
authority_host: str
|
||||
sharepoint_domain_suffix: str
|
||||
|
||||
|
||||
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.Global,
|
||||
graph_host="https://graph.microsoft.com",
|
||||
authority_host="https://login.microsoftonline.com",
|
||||
sharepoint_domain_suffix="sharepoint.com",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.USGovernmentHigh,
|
||||
graph_host="https://graph.microsoft.us",
|
||||
authority_host="https://login.microsoftonline.us",
|
||||
sharepoint_domain_suffix="sharepoint.us",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.USGovernmentDoD,
|
||||
graph_host="https://dod-graph.microsoft.us",
|
||||
authority_host="https://login.microsoftonline.us",
|
||||
sharepoint_domain_suffix="sharepoint.us",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.China,
|
||||
graph_host="https://microsoftgraph.chinacloudapi.cn",
|
||||
authority_host="https://login.chinacloudapi.cn",
|
||||
sharepoint_domain_suffix="sharepoint.cn",
|
||||
),
|
||||
MicrosoftGraphEnvironment(
|
||||
environment=AzureEnvironment.Germany,
|
||||
graph_host="https://graph.microsoft.de",
|
||||
authority_host="https://login.microsoftonline.de",
|
||||
sharepoint_domain_suffix="sharepoint.de",
|
||||
),
|
||||
]
|
||||
|
||||
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
|
||||
env.graph_host: env for env in _ENVIRONMENTS
|
||||
}
|
||||
|
||||
|
||||
def resolve_microsoft_environment(
|
||||
graph_api_host: str,
|
||||
authority_host: str,
|
||||
) -> MicrosoftGraphEnvironment:
|
||||
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
|
||||
|
||||
Raises ``ConnectorValidationError`` when the combination is unknown or
|
||||
internally inconsistent (e.g. a GCC-High graph host paired with a
|
||||
commercial authority host).
|
||||
"""
|
||||
graph_api_host = graph_api_host.rstrip("/")
|
||||
authority_host = authority_host.rstrip("/")
|
||||
|
||||
env = _GRAPH_HOST_INDEX.get(graph_api_host)
|
||||
if env is None:
|
||||
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
|
||||
raise ConnectorValidationError(
|
||||
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
|
||||
f"Recognised hosts: {known}"
|
||||
)
|
||||
|
||||
if env.authority_host != authority_host:
|
||||
raise ConnectorValidationError(
|
||||
f"Authority host '{authority_host}' is inconsistent with "
|
||||
f"graph API host '{graph_api_host}'. "
|
||||
f"Expected authority host '{env.authority_host}' "
|
||||
f"for the {env.environment} environment."
|
||||
)
|
||||
|
||||
return env
|
||||
@@ -6,7 +6,6 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
key: [str(item) for item in val] if isinstance(val, list) else str(val)
|
||||
for key, val in v.items()
|
||||
}
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -838,20 +837,10 @@ class SharepointConnector(
|
||||
self._cached_rest_ctx: ClientContext | None = None
|
||||
self._cached_rest_ctx_url: str | None = None
|
||||
self._cached_rest_ctx_created_at: float = 0.0
|
||||
|
||||
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
|
||||
self._azure_environment = resolved_env.environment
|
||||
self.authority_host = resolved_env.authority_host
|
||||
self.graph_api_host = resolved_env.graph_host
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
self.graph_api_base = f"{self.graph_api_host}/v1.0"
|
||||
self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
|
||||
if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
|
||||
logger.warning(
|
||||
f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
|
||||
f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
|
||||
f"for the {resolved_env.environment} environment. "
|
||||
f"Using '{resolved_env.sharepoint_domain_suffix}'."
|
||||
)
|
||||
self.sharepoint_domain_suffix = sharepoint_domain_suffix
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Validate that at least one content type is enabled
|
||||
@@ -1603,7 +1592,6 @@ class SharepointConnector(
|
||||
if certificate_data is None:
|
||||
raise RuntimeError("Failed to load certificate")
|
||||
|
||||
logger.info(f"Creating MSAL app with authority url {authority_url}")
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
@@ -1635,9 +1623,7 @@ class SharepointConnector(
|
||||
raise ConnectorValidationError("Failed to acquire token for graph")
|
||||
return token
|
||||
|
||||
self._graph_client = GraphClient(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
self._graph_client = GraphClient(_acquire_token_for_graph)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
org = self.graph_client.organization.get().execute_query()
|
||||
if not org or len(org) == 0:
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -74,11 +73,8 @@ class TeamsConnector(
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
self.max_workers = max_workers
|
||||
self.requested_team_list: list[str] = teams
|
||||
|
||||
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
|
||||
self._azure_environment = resolved_env.environment
|
||||
self.authority_host = resolved_env.authority_host
|
||||
self.graph_api_host = resolved_env.graph_host
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
|
||||
# impls for BaseConnector
|
||||
|
||||
@@ -110,9 +106,7 @@ class TeamsConnector(
|
||||
|
||||
return token
|
||||
|
||||
self.graph_client = GraphClient(
|
||||
_acquire_token_func, environment=self._azure_environment
|
||||
)
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
|
||||
@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -40,6 +40,7 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
@@ -59,11 +60,12 @@ def _build_index_filters(
|
||||
|
||||
base_filters = user_provided_filters or BaseFilters()
|
||||
|
||||
document_set_filter = (
|
||||
base_filters.document_set
|
||||
if base_filters.document_set is not None
|
||||
else persona_document_sets
|
||||
)
|
||||
if (
|
||||
user_provided_filters
|
||||
and user_provided_filters.document_set is None
|
||||
and persona_document_sets is not None
|
||||
):
|
||||
base_filters.document_set = persona_document_sets
|
||||
|
||||
time_filter = base_filters.time_cutoff or persona_time_cutoff
|
||||
source_filter = base_filters.source_type
|
||||
@@ -118,8 +120,9 @@ def _build_index_filters(
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
document_set=persona_document_sets,
|
||||
time_cutoff=time_filter,
|
||||
tags=base_filters.tags,
|
||||
access_control_list=user_acl_filters,
|
||||
@@ -265,6 +268,8 @@ def search_pipeline(
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
@@ -299,6 +304,7 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import CodeInterpreterServer
|
||||
|
||||
|
||||
def fetch_code_interpreter_server(
|
||||
db_session: Session,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
return server
|
||||
|
||||
|
||||
def update_code_interpreter_server_enabled(
|
||||
db_session: Session,
|
||||
enabled: bool,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
server.server_enabled = enabled
|
||||
db_session.commit()
|
||||
return server
|
||||
@@ -1,102 +1,11 @@
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
|
||||
tenant_schemas: list[str], head_rev: str
|
||||
) -> list[str]:
|
||||
"""Return only schemas whose current alembic version is not at head.
|
||||
|
||||
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
|
||||
into a temp table one at a time. This avoids building a massive UNION ALL
|
||||
query (which locks the DB and times out at 17k+ schemas) and instead
|
||||
acquires locks sequentially, one schema per iteration.
|
||||
"""
|
||||
if not tenant_schemas:
|
||||
return []
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Populate a temp input table with exactly the schemas we care about.
|
||||
# The DO block reads from this table so it only iterates the requested
|
||||
# schemas instead of every tenant_% schema in the database.
|
||||
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
|
||||
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
|
||||
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT INTO _tenant_schemas_input (schema_name) "
|
||||
"SELECT unnest(CAST(:schemas AS text[]))"
|
||||
),
|
||||
{"schemas": tenant_schemas},
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE TEMP TABLE _alembic_version_snapshot "
|
||||
"(schema_name text, version_num text)"
|
||||
)
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE
|
||||
s text;
|
||||
schemas text[];
|
||||
BEGIN
|
||||
SELECT array_agg(schema_name) INTO schemas
|
||||
FROM _tenant_schemas_input;
|
||||
|
||||
IF schemas IS NULL THEN
|
||||
RAISE NOTICE 'No tenant schemas found.';
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
FOREACH s IN ARRAY schemas LOOP
|
||||
BEGIN
|
||||
EXECUTE format(
|
||||
'INSERT INTO _alembic_version_snapshot
|
||||
SELECT %L, version_num FROM %I.alembic_version',
|
||||
s, s
|
||||
);
|
||||
EXCEPTION
|
||||
-- undefined_table: schema exists but has no alembic_version
|
||||
-- table yet (new tenant, not yet migrated).
|
||||
-- invalid_schema_name: tenant is registered but its
|
||||
-- PostgreSQL schema does not exist yet (e.g. provisioning
|
||||
-- incomplete). Both cases mean no version is available and
|
||||
-- the schema will be included in the migration list.
|
||||
WHEN undefined_table THEN NULL;
|
||||
WHEN invalid_schema_name THEN NULL;
|
||||
END;
|
||||
END LOOP;
|
||||
END;
|
||||
$$
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
rows = conn.execute(
|
||||
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
|
||||
)
|
||||
version_by_schema = {row[0]: row[1] for row in rows}
|
||||
|
||||
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
|
||||
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
|
||||
|
||||
# Schemas missing from the snapshot have no alembic_version table yet and
|
||||
# also need migration. version_by_schema.get(s) returns None for those,
|
||||
# and None != head_rev, so they are included automatically.
|
||||
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
|
||||
@@ -4270,6 +4270,9 @@ class UserFile(Base):
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
needs_persona_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4940,11 +4943,6 @@ class ScimUserMapping(Base):
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
department: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
manager: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
|
||||
@@ -765,6 +765,9 @@ def mark_persona_as_deleted(
|
||||
) -> None:
|
||||
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
|
||||
persona.deleted = True
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -776,11 +779,13 @@ def mark_persona_as_not_deleted(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
|
||||
)
|
||||
if persona.deleted:
|
||||
persona.deleted = False
|
||||
db_session.commit()
|
||||
else:
|
||||
if not persona.deleted:
|
||||
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
|
||||
persona.deleted = False
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_delete_persona_by_name(
|
||||
@@ -846,6 +851,20 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
) -> None:
|
||||
"""Flag the given UserFile rows so the background sync task picks them up
|
||||
and updates their persona metadata in the vector DB."""
|
||||
if not user_file_ids:
|
||||
return
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
|
||||
{UserFile.needs_persona_sync: True},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -1034,8 +1053,13 @@ def upsert_persona(
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
old_file_ids = {uf.id for uf in existing_persona.user_files}
|
||||
new_file_ids = {uf.id for uf in (user_files or [])}
|
||||
affected_file_ids = old_file_ids | new_file_ids
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
|
||||
|
||||
if hierarchy_node_ids is not None:
|
||||
existing_persona.hierarchy_nodes.clear()
|
||||
@@ -1089,6 +1113,8 @@ def upsert_persona(
|
||||
attached_documents=attached_documents or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
if user_files:
|
||||
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
|
||||
persona = new_persona
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
@@ -64,6 +64,19 @@ def fetch_user_project_ids_for_user_files(
|
||||
}
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch persona (assistant) ids for specified user files."""
|
||||
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [persona.id for persona in user_file.assistants]
|
||||
for user_file in results
|
||||
}
|
||||
|
||||
|
||||
def update_last_accessed_at_for_user_files(
|
||||
user_file_ids: list[UUID],
|
||||
db_session: Session,
|
||||
|
||||
@@ -121,6 +121,7 @@ class VespaDocumentUserFields:
|
||||
"""
|
||||
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -148,6 +148,7 @@ class MetadataUpdateRequest(BaseModel):
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
persona_ids: set[int] | None = None
|
||||
|
||||
|
||||
class IndexRetrievalFilters(BaseModel):
|
||||
|
||||
@@ -50,6 +50,7 @@ from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
@@ -215,6 +216,7 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
# OpenSearch and it will not store any data at all for this field, which
|
||||
# is different from supplying an empty list.
|
||||
user_projects=chunk.user_project or None,
|
||||
personas=chunk.personas or None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
@@ -362,6 +364,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
if user_fields and user_fields.personas
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -709,6 +716,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
|
||||
update_request.project_ids
|
||||
)
|
||||
if update_request.persona_ids is not None:
|
||||
properties_to_update[PERSONAS_FIELD_NAME] = list(
|
||||
update_request.persona_ids
|
||||
)
|
||||
|
||||
if not properties_to_update:
|
||||
if len(update_request.document_ids) > 1:
|
||||
|
||||
@@ -41,6 +41,7 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
USER_PROJECTS_FIELD_NAME = "user_projects"
|
||||
PERSONAS_FIELD_NAME = "personas"
|
||||
DOCUMENT_ID_FIELD_NAME = "document_id"
|
||||
CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
@@ -156,6 +157,7 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
@@ -485,6 +487,7 @@ class DocumentSchema:
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
|
||||
PERSONAS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
@@ -144,6 +145,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -202,6 +204,7 @@ class DocumentQuery:
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -267,6 +270,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -334,6 +338,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -496,6 +501,7 @@ class DocumentQuery:
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -530,6 +536,8 @@ class DocumentQuery:
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -627,6 +635,9 @@ class DocumentQuery:
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
@@ -780,6 +791,9 @@ class DocumentQuery:
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
|
||||
@@ -181,6 +181,11 @@ schema {{ schema_name }} {
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field personas type array<int> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
|
||||
@@ -689,6 +689,9 @@ class VespaIndex(DocumentIndex):
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
document_ids=[doc_id],
|
||||
doc_id_to_chunk_cnt={
|
||||
@@ -699,6 +702,7 @@ class VespaIndex(DocumentIndex):
|
||||
boost=fields.boost if fields is not None else None,
|
||||
hidden=fields.hidden if fields is not None else None,
|
||||
project_ids=project_ids,
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.document_index.vespa_constants import METADATA
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
|
||||
@@ -218,6 +219,7 @@ def _index_vespa_chunk(
|
||||
# still called `image_file_name` in Vespa for backwards compatibility
|
||||
IMAGE_FILE_NAME: chunk.image_file_id,
|
||||
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
|
||||
PERSONAS: chunk.personas if chunk.personas is not None else [],
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -149,6 +150,18 @@ def build_vespa_filters(
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
) -> str:
|
||||
if persona_id is None:
|
||||
return ""
|
||||
try:
|
||||
pid = int(persona_id)
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -192,6 +205,9 @@ def build_vespa_filters(
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
|
||||
@@ -183,6 +183,10 @@ def _update_single_chunk(
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _Personas(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
@@ -193,6 +197,7 @@ def _update_single_chunk(
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
personas: _Personas | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
@@ -227,6 +232,11 @@ def _update_single_chunk(
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
personas_update: _Personas | None = (
|
||||
_Personas(assign=list(update_request.persona_ids))
|
||||
if update_request.persona_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
@@ -234,6 +244,7 @@ def _update_single_chunk(
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
personas=personas_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
@@ -554,9 +565,10 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Avoid over-fetching a very large candidate set for global-phase reranking.
|
||||
# Keep enough headroom for quality while capping cost on larger indices.
|
||||
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
|
||||
# Needs to be at least as much as the rerank-count value set in the
|
||||
# Vespa schema config. Otherwise we would be getting fewer results than
|
||||
# expected for reranking.
|
||||
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self._index_name)
|
||||
|
||||
@@ -58,6 +58,7 @@ DOCUMENT_SETS = "document_sets"
|
||||
USER_FILE = "user_file"
|
||||
USER_FOLDER = "user_folder"
|
||||
USER_PROJECT = "user_project"
|
||||
PERSONAS = "personas"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
|
||||
@@ -12,9 +12,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
|
||||
_DALL_E_2_MODEL_NAME = "dall-e-2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
@@ -56,25 +53,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
deployment_name=credentials.deployment_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# Azure GPT image models support up to 16 input images for edits.
|
||||
return 16
|
||||
|
||||
def _normalize_model_name(self, model: str) -> str:
|
||||
return model.rsplit("/", 1)[-1]
|
||||
|
||||
def _model_supports_image_edits(self, model: str) -> bool:
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
return (
|
||||
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
|
||||
or normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
)
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -82,44 +60,14 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
deployment = self._deployment_name or model
|
||||
model_name = f"azure/{deployment}"
|
||||
|
||||
if reference_images:
|
||||
if not self._model_supports_image_edits(model):
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support image edits with reference images."
|
||||
)
|
||||
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
if (
|
||||
normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
and len(reference_images) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Model 'dall-e-2' only supports a single reference image for edits."
|
||||
)
|
||||
|
||||
from litellm import image_edit
|
||||
|
||||
return image_edit(
|
||||
image=[image.data for image in reference_images],
|
||||
prompt=prompt,
|
||||
model=model_name,
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
api_version=self._api_version,
|
||||
size=size,
|
||||
n=n,
|
||||
quality=quality,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
prompt=prompt,
|
||||
model=model_name,
|
||||
|
||||
@@ -12,9 +12,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
|
||||
_DALL_E_2_MODEL_NAME = "dall-e-2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
@@ -42,25 +39,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
api_base=credentials.api_base,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# GPT image models support up to 16 input images for edits.
|
||||
return 16
|
||||
|
||||
def _normalize_model_name(self, model: str) -> str:
|
||||
return model.rsplit("/", 1)[-1]
|
||||
|
||||
def _model_supports_image_edits(self, model: str) -> bool:
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
return (
|
||||
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
|
||||
or normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
)
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -68,38 +46,9 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
if reference_images:
|
||||
if not self._model_supports_image_edits(model):
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support image edits with reference images."
|
||||
)
|
||||
|
||||
normalized_model = self._normalize_model_name(model)
|
||||
if (
|
||||
normalized_model == self._DALL_E_2_MODEL_NAME
|
||||
and len(reference_images) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Model 'dall-e-2' only supports a single reference image for edits."
|
||||
)
|
||||
|
||||
from litellm import image_edit
|
||||
|
||||
return image_edit(
|
||||
image=[image.data for image in reference_images],
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
size=size,
|
||||
n=n,
|
||||
quality=quality,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
|
||||
@@ -146,6 +146,7 @@ class DocumentIndexingBatchAdapter:
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
@@ -119,6 +120,10 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -182,7 +187,7 @@ class UserFileIndexingAdapter:
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
# we are going to index userfiles only once, so we just set the boost to the default
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
|
||||
@@ -49,7 +49,6 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
@@ -229,8 +228,6 @@ def index_doc_batch_prepare(
|
||||
) -> DocumentBatchPrepareContext | None:
|
||||
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
|
||||
This preceeds indexing it into the actual document index."""
|
||||
documents = sanitize_documents_for_postgres(documents)
|
||||
|
||||
# Create a trimmed list of docs that don't have a newer updated at
|
||||
# Shortcuts the time-consuming flow on connector index retries
|
||||
document_ids: list[str] = [document.id for document in documents]
|
||||
|
||||
@@ -112,6 +112,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
user_project: list[int]
|
||||
personas: list[int]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
# Full ancestor path from root hierarchy node to document's parent.
|
||||
@@ -126,6 +127,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
user_project: list[int],
|
||||
personas: list[int],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
@@ -137,6 +139,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
return value.replace("\x00", "")
|
||||
|
||||
|
||||
def _sanitize_json_like(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _sanitize_string(value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_sanitize_json_like(item) for item in value]
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_sanitize_json_like(item) for item in value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[Any, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
cleaned_key = _sanitize_string(key) if isinstance(key, str) else key
|
||||
sanitized[cleaned_key] = _sanitize_json_like(nested_value)
|
||||
return sanitized
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _sanitize_expert_info(expert: BasicExpertInfo) -> BasicExpertInfo:
|
||||
return expert.model_copy(
|
||||
update={
|
||||
"display_name": (
|
||||
_sanitize_string(expert.display_name)
|
||||
if expert.display_name is not None
|
||||
else None
|
||||
),
|
||||
"first_name": (
|
||||
_sanitize_string(expert.first_name)
|
||||
if expert.first_name is not None
|
||||
else None
|
||||
),
|
||||
"middle_initial": (
|
||||
_sanitize_string(expert.middle_initial)
|
||||
if expert.middle_initial is not None
|
||||
else None
|
||||
),
|
||||
"last_name": (
|
||||
_sanitize_string(expert.last_name)
|
||||
if expert.last_name is not None
|
||||
else None
|
||||
),
|
||||
"email": (
|
||||
_sanitize_string(expert.email) if expert.email is not None else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_external_access(external_access: ExternalAccess) -> ExternalAccess:
|
||||
return ExternalAccess(
|
||||
external_user_emails={
|
||||
_sanitize_string(email) for email in external_access.external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
_sanitize_string(group_id)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
},
|
||||
is_public=external_access.is_public,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_document_for_postgres(document: Document) -> Document:
|
||||
cleaned_doc = document.model_copy(deep=True)
|
||||
|
||||
cleaned_doc.id = _sanitize_string(cleaned_doc.id)
|
||||
cleaned_doc.semantic_identifier = _sanitize_string(cleaned_doc.semantic_identifier)
|
||||
if cleaned_doc.title is not None:
|
||||
cleaned_doc.title = _sanitize_string(cleaned_doc.title)
|
||||
if cleaned_doc.parent_hierarchy_raw_node_id is not None:
|
||||
cleaned_doc.parent_hierarchy_raw_node_id = _sanitize_string(
|
||||
cleaned_doc.parent_hierarchy_raw_node_id
|
||||
)
|
||||
|
||||
cleaned_doc.metadata = {
|
||||
_sanitize_string(key): (
|
||||
[_sanitize_string(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else _sanitize_string(value)
|
||||
)
|
||||
for key, value in cleaned_doc.metadata.items()
|
||||
}
|
||||
|
||||
if cleaned_doc.doc_metadata is not None:
|
||||
cleaned_doc.doc_metadata = _sanitize_json_like(cleaned_doc.doc_metadata)
|
||||
|
||||
if cleaned_doc.primary_owners is not None:
|
||||
cleaned_doc.primary_owners = [
|
||||
_sanitize_expert_info(expert) for expert in cleaned_doc.primary_owners
|
||||
]
|
||||
if cleaned_doc.secondary_owners is not None:
|
||||
cleaned_doc.secondary_owners = [
|
||||
_sanitize_expert_info(expert) for expert in cleaned_doc.secondary_owners
|
||||
]
|
||||
|
||||
if cleaned_doc.external_access is not None:
|
||||
cleaned_doc.external_access = _sanitize_external_access(
|
||||
cleaned_doc.external_access
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
section.link = _sanitize_string(section.link)
|
||||
if section.text is not None:
|
||||
section.text = _sanitize_string(section.text)
|
||||
if section.image_file_id is not None:
|
||||
section.image_file_id = _sanitize_string(section.image_file_id)
|
||||
|
||||
return cleaned_doc
|
||||
|
||||
|
||||
def sanitize_documents_for_postgres(documents: list[Document]) -> list[Document]:
|
||||
return [sanitize_document_for_postgres(document) for document in documents]
|
||||
|
||||
|
||||
def sanitize_hierarchy_node_for_postgres(node: HierarchyNode) -> HierarchyNode:
|
||||
cleaned_node = node.model_copy(deep=True)
|
||||
|
||||
cleaned_node.raw_node_id = _sanitize_string(cleaned_node.raw_node_id)
|
||||
cleaned_node.display_name = _sanitize_string(cleaned_node.display_name)
|
||||
if cleaned_node.raw_parent_id is not None:
|
||||
cleaned_node.raw_parent_id = _sanitize_string(cleaned_node.raw_parent_id)
|
||||
if cleaned_node.link is not None:
|
||||
cleaned_node.link = _sanitize_string(cleaned_node.link)
|
||||
|
||||
if cleaned_node.external_access is not None:
|
||||
cleaned_node.external_access = _sanitize_external_access(
|
||||
cleaned_node.external_access
|
||||
)
|
||||
|
||||
return cleaned_node
|
||||
|
||||
|
||||
def sanitize_hierarchy_nodes_for_postgres(
|
||||
nodes: list[HierarchyNode],
|
||||
) -> list[HierarchyNode]:
|
||||
return [sanitize_hierarchy_node_for_postgres(node) for node in nodes]
|
||||
@@ -97,9 +97,6 @@ from onyx.server.features.web_search.api import router as web_search_router
|
||||
from onyx.server.federated.api import router as federated_router
|
||||
from onyx.server.kg.api import admin_router as kg_admin_router
|
||||
from onyx.server.manage.administrative import router as admin_router
|
||||
from onyx.server.manage.code_interpreter.api import (
|
||||
admin_router as code_interpreter_admin_router,
|
||||
)
|
||||
from onyx.server.manage.discord_bot.api import router as discord_bot_router
|
||||
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from onyx.server.manage.embedding.api import basic_router as embedding_router
|
||||
@@ -424,9 +421,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, kg_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, code_interpreter_admin_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, image_generation_admin_router
|
||||
)
|
||||
|
||||
@@ -1,68 +1,14 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
|
||||
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
|
||||
if "[[" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _CITATION_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
normalized_message = _normalize_citation_link_destinations(message)
|
||||
result = md(normalized_message)
|
||||
result = md(message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
|
||||
@@ -762,43 +762,6 @@ def download_webapp(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/download-directory/{path:path}")
|
||||
def download_directory(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""
|
||||
Download a directory as a zip file.
|
||||
|
||||
Returns the specified directory as a zip archive.
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
result = session_manager.download_directory(session_id, user_id, path)
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if "path traversal" in error_message.lower():
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Directory not found")
|
||||
|
||||
zip_bytes, filename = result
|
||||
|
||||
return Response(
|
||||
content=zip_bytes,
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/upload", response_model=UploadResponse)
|
||||
def upload_file_endpoint(
|
||||
session_id: UUID,
|
||||
|
||||
@@ -107,23 +107,27 @@ def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int,
|
||||
)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if (
|
||||
cc_pair.connector.source == DocumentSource.CRAFT_FILE
|
||||
and cc_pair.creator_id == user.id
|
||||
):
|
||||
if cc_pair.connector.source == DocumentSource.CRAFT_FILE:
|
||||
return cc_pair.connector.id, cc_pair.credential.id
|
||||
|
||||
# No cc_pair for this user — find or create the shared CRAFT_FILE connector
|
||||
# Check for orphaned connector (created but cc_pair creation failed previously)
|
||||
existing_connectors = fetch_connectors(
|
||||
db_session, sources=[DocumentSource.CRAFT_FILE]
|
||||
)
|
||||
connector_id: int | None = None
|
||||
orphaned_connector = None
|
||||
for conn in existing_connectors:
|
||||
if conn.name == USER_LIBRARY_CONNECTOR_NAME:
|
||||
connector_id = conn.id
|
||||
if conn.name != USER_LIBRARY_CONNECTOR_NAME:
|
||||
continue
|
||||
if not conn.credentials:
|
||||
orphaned_connector = conn
|
||||
break
|
||||
|
||||
if connector_id is None:
|
||||
if orphaned_connector:
|
||||
connector_id = orphaned_connector.id
|
||||
logger.info(
|
||||
f"Found orphaned User Library connector {connector_id}, completing setup"
|
||||
)
|
||||
else:
|
||||
connector_data = ConnectorBase(
|
||||
name=USER_LIBRARY_CONNECTOR_NAME,
|
||||
source=DocumentSource.CRAFT_FILE,
|
||||
|
||||
Binary file not shown.
@@ -68,7 +68,6 @@ from onyx.server.features.build.db.sandbox import create_sandbox__no_commit
|
||||
from onyx.server.features.build.db.sandbox import get_running_sandbox_count_by_tenant
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_session_id
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.db.sandbox import get_snapshots_for_session
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
@@ -647,30 +646,16 @@ class SessionManager:
|
||||
|
||||
if sandbox and sandbox.status.is_active():
|
||||
# Quick health check to verify sandbox is actually responsive
|
||||
# AND verify the session workspace still exists on disk
|
||||
# (it may have been wiped if the sandbox was re-provisioned)
|
||||
is_healthy = self._sandbox_manager.health_check(sandbox.id, timeout=5.0)
|
||||
workspace_exists = (
|
||||
is_healthy
|
||||
and self._sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, existing.id
|
||||
)
|
||||
)
|
||||
if is_healthy and workspace_exists:
|
||||
if self._sandbox_manager.health_check(sandbox.id, timeout=5.0):
|
||||
logger.info(
|
||||
f"Returning existing empty session {existing.id} for user {user_id}"
|
||||
)
|
||||
return existing
|
||||
elif not is_healthy:
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} has unhealthy sandbox {sandbox.id}. "
|
||||
f"Deleting and creating fresh session."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} workspace missing in sandbox "
|
||||
f"{sandbox.id}. Deleting and creating fresh session."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Empty session {existing.id} has no active sandbox "
|
||||
@@ -1050,23 +1035,6 @@ class SessionManager:
|
||||
# workspace cleanup fails (e.g., if pod is already terminated)
|
||||
logger.warning(f"Failed to cleanup session workspace {session_id}: {e}")
|
||||
|
||||
# Delete snapshot files from S3 before removing DB records
|
||||
snapshots = get_snapshots_for_session(self._db_session, session_id)
|
||||
if snapshots:
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.features.build.sandbox.manager.snapshot_manager import (
|
||||
SnapshotManager,
|
||||
)
|
||||
|
||||
snapshot_manager = SnapshotManager(get_default_file_store())
|
||||
for snapshot in snapshots:
|
||||
try:
|
||||
snapshot_manager.delete_snapshot(snapshot.storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to delete snapshot file {snapshot.storage_path}: {e}"
|
||||
)
|
||||
|
||||
# Delete session (uses flush, caller commits)
|
||||
return delete_build_session__no_commit(session_id, user_id, self._db_session)
|
||||
|
||||
@@ -1935,94 +1903,6 @@ class SessionManager:
|
||||
|
||||
return zip_buffer.getvalue(), filename
|
||||
|
||||
def download_directory(
|
||||
self,
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
Create a zip file of an arbitrary directory in the session workspace.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID
|
||||
user_id: The user ID to verify ownership
|
||||
path: Relative path to the directory (within session workspace)
|
||||
|
||||
Returns:
|
||||
Tuple of (zip_bytes, filename) or None if session not found
|
||||
|
||||
Raises:
|
||||
ValueError: If path traversal attempted or path is not a directory
|
||||
"""
|
||||
# Verify session ownership
|
||||
session = get_build_session(session_id, user_id, self._db_session)
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
|
||||
if sandbox is None:
|
||||
return None
|
||||
|
||||
# Check if directory exists
|
||||
try:
|
||||
self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Recursively collect all files
|
||||
def collect_files(dir_path: str) -> list[tuple[str, str]]:
|
||||
"""Collect all files recursively, returning (full_path, arcname) tuples."""
|
||||
files: list[tuple[str, str]] = []
|
||||
try:
|
||||
entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=dir_path,
|
||||
)
|
||||
for entry in entries:
|
||||
if entry.is_directory:
|
||||
files.extend(collect_files(entry.path))
|
||||
else:
|
||||
# arcname is relative to the target directory
|
||||
prefix_len = len(path) + 1 # +1 for trailing slash
|
||||
arcname = entry.path[prefix_len:]
|
||||
files.append((entry.path, arcname))
|
||||
except ValueError:
|
||||
pass
|
||||
return files
|
||||
|
||||
file_list = collect_files(path)
|
||||
|
||||
# Create zip file in memory
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
|
||||
for full_path, arcname in file_list:
|
||||
try:
|
||||
content = self._sandbox_manager.read_file(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=full_path,
|
||||
)
|
||||
zip_file.writestr(arcname, content)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# Use the directory name for the zip filename
|
||||
dir_name = Path(path).name
|
||||
safe_name = "".join(
|
||||
c if c.isalnum() or c in ("-", "_", ".") else "_" for c in dir_name
|
||||
)
|
||||
filename = f"{safe_name}.zip"
|
||||
|
||||
return zip_buffer.getvalue(), filename
|
||||
|
||||
# =========================================================================
|
||||
# File System Operations
|
||||
# =========================================================================
|
||||
@@ -2057,18 +1937,11 @@ class SessionManager:
|
||||
return None
|
||||
|
||||
# Use sandbox manager to list directory (works for both local and K8s)
|
||||
# If the directory doesn't exist (e.g., session workspace not yet loaded),
|
||||
# return an empty listing rather than erroring out.
|
||||
try:
|
||||
raw_entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "path traversal" in str(e).lower():
|
||||
raise
|
||||
return DirectoryListing(path=path, entries=[])
|
||||
raw_entries = self._sandbox_manager.list_directory(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
path=path,
|
||||
)
|
||||
|
||||
# Filter hidden files and directories
|
||||
entries: list[FileSystemEntry] = [
|
||||
|
||||
@@ -12,18 +12,11 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -34,7 +27,6 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -55,33 +47,6 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
f"Skipping immediate project sync for user_file_id={user_file_id} due to "
|
||||
f"queue depth {queue_depth}>{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}. "
|
||||
"It will be picked up by beat later."
|
||||
)
|
||||
return
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
enqueued = enqueue_user_file_project_sync_task(
|
||||
celery_app=client_app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
if not enqueued:
|
||||
logger.info(
|
||||
f"Skipped duplicate project sync enqueue for user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Triggered project sync for user_file_id={user_file_id}")
|
||||
|
||||
|
||||
@router.get("", tags=PUBLIC_API_TAGS)
|
||||
def get_projects(
|
||||
user: User = Depends(current_user),
|
||||
@@ -224,7 +189,15 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -268,7 +241,15 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/code-interpreter")
|
||||
|
||||
|
||||
@admin_router.get("/health")
|
||||
def get_code_interpreter_health(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> CodeInterpreterServerHealth:
|
||||
try:
|
||||
client = CodeInterpreterClient()
|
||||
return CodeInterpreterServerHealth(healthy=client.health())
|
||||
except ValueError:
|
||||
return CodeInterpreterServerHealth(healthy=False)
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def get_code_interpreter(
|
||||
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
|
||||
) -> CodeInterpreterServer:
|
||||
ci_server = fetch_code_interpreter_server(db_session)
|
||||
return CodeInterpreterServer(enabled=ci_server.server_enabled)
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
def update_code_interpreter(
|
||||
update: CodeInterpreterServer,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_code_interpreter_server_enabled(
|
||||
db_session=db_session,
|
||||
enabled=update.enabled,
|
||||
)
|
||||
@@ -1,9 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterServer(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class CodeInterpreterServerHealth(BaseModel):
|
||||
healthy: bool
|
||||
@@ -587,7 +587,6 @@ def handle_send_chat_message(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
additional_context=chat_message_req.additional_context,
|
||||
external_state_container=state_container,
|
||||
)
|
||||
result = gather_stream_full(packets, state_container)
|
||||
@@ -610,7 +609,6 @@ def handle_send_chat_message(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
additional_context=chat_message_req.additional_context,
|
||||
external_state_container=state_container,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
@@ -125,11 +125,6 @@ class SendMessageRequest(BaseModel):
|
||||
# - No CitationInfo packets are emitted during streaming
|
||||
include_citations: bool = True
|
||||
|
||||
# Additional context injected into the LLM call but NOT stored in the DB
|
||||
# (not shown in chat history). Used e.g. by the Chrome extension to pass
|
||||
# the current tab URL when "Read this tab" is enabled.
|
||||
additional_context: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chat_session_id_or_info(self) -> "SendMessageRequest":
|
||||
# If neither is provided, default to creating a new chat session using the
|
||||
|
||||
@@ -54,6 +54,7 @@ logger = setup_logger()
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,6 +181,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -427,6 +429,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
@@ -39,39 +36,6 @@ class ExecuteResponse(BaseModel):
|
||||
files: list[WorkspaceFile]
|
||||
|
||||
|
||||
class StreamOutputEvent(BaseModel):
|
||||
"""SSE 'output' event: a chunk of stdout or stderr"""
|
||||
|
||||
stream: Literal["stdout", "stderr"]
|
||||
data: str
|
||||
|
||||
|
||||
class StreamResultEvent(BaseModel):
|
||||
"""SSE 'result' event: final execution result"""
|
||||
|
||||
exit_code: int | None
|
||||
timed_out: bool
|
||||
duration_ms: int
|
||||
files: list[WorkspaceFile]
|
||||
|
||||
|
||||
class StreamErrorEvent(BaseModel):
|
||||
"""SSE 'error' event: execution-level error"""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
StreamEvent = Union[StreamOutputEvent, StreamResultEvent, StreamErrorEvent]
|
||||
|
||||
_SSE_EVENT_MAP: dict[
|
||||
str, type[StreamOutputEvent | StreamResultEvent | StreamErrorEvent]
|
||||
] = {
|
||||
"output": StreamOutputEvent,
|
||||
"result": StreamResultEvent,
|
||||
"error": StreamErrorEvent,
|
||||
}
|
||||
|
||||
|
||||
class CodeInterpreterClient:
|
||||
"""Client for Code Interpreter service"""
|
||||
|
||||
@@ -81,34 +45,6 @@ class CodeInterpreterClient:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None,
|
||||
timeout_ms: int,
|
||||
files: list[FileInput] | None,
|
||||
) -> dict:
|
||||
payload: dict = {
|
||||
"code": code,
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
if stdin is not None:
|
||||
payload["stdin"] = stdin
|
||||
if files:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
@@ -116,110 +52,25 @@ class CodeInterpreterClient:
|
||||
timeout_ms: int = 30000,
|
||||
files: list[FileInput] | None = None,
|
||||
) -> ExecuteResponse:
|
||||
"""Execute Python code (batch)"""
|
||||
"""Execute Python code"""
|
||||
url = f"{self.base_url}/v1/execute"
|
||||
payload = self._build_payload(code, stdin, timeout_ms, files)
|
||||
|
||||
payload = {
|
||||
"code": code,
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
|
||||
if stdin is not None:
|
||||
payload["stdin"] = stdin
|
||||
|
||||
if files:
|
||||
payload["files"] = files
|
||||
|
||||
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
|
||||
response.raise_for_status()
|
||||
|
||||
return ExecuteResponse(**response.json())
|
||||
|
||||
def execute_streaming(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
files: list[FileInput] | None = None,
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Execute Python code with streaming SSE output.
|
||||
|
||||
Yields StreamEvent objects (StreamOutputEvent, StreamResultEvent,
|
||||
StreamErrorEvent) as execution progresses. Falls back to batch
|
||||
execution if the streaming endpoint is not available (older
|
||||
code-interpreter versions).
|
||||
"""
|
||||
url = f"{self.base_url}/v1/execute/stream"
|
||||
payload = self._build_payload(code, stdin, timeout_ms, files)
|
||||
|
||||
response = self.session.post(
|
||||
url,
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=timeout_ms / 1000 + 10,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
logger.info(
|
||||
"Streaming endpoint not available, " "falling back to batch execution"
|
||||
)
|
||||
response.close()
|
||||
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
|
||||
def _parse_sse(
|
||||
self, response: requests.Response
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Parse SSE streaming response into StreamEvent objects.
|
||||
|
||||
Expected format per event:
|
||||
event: <type>
|
||||
data: <json>
|
||||
<blank line>
|
||||
"""
|
||||
event_type: str | None = None
|
||||
data_lines: list[str] = []
|
||||
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line is None:
|
||||
continue
|
||||
|
||||
if line == "":
|
||||
# Blank line marks end of an SSE event
|
||||
if event_type is not None and data_lines:
|
||||
data = "\n".join(data_lines)
|
||||
model_cls = _SSE_EVENT_MAP.get(event_type)
|
||||
if model_cls is not None:
|
||||
yield model_cls(**json.loads(data))
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event type: {event_type}")
|
||||
event_type = None
|
||||
data_lines = []
|
||||
elif line.startswith("event:"):
|
||||
event_type = line[len("event:") :].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_lines.append(line[len("data:") :].strip())
|
||||
|
||||
if event_type is not None or data_lines:
|
||||
logger.warning(
|
||||
f"SSE stream ended with incomplete event: "
|
||||
f"event_type={event_type}, data_lines={data_lines}"
|
||||
)
|
||||
|
||||
def _batch_as_stream(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None,
|
||||
timeout_ms: int,
|
||||
files: list[FileInput] | None,
|
||||
) -> Generator[StreamEvent, None, None]:
|
||||
"""Execute via batch endpoint and yield results as stream events."""
|
||||
result = self.execute(code, stdin, timeout_ms, files)
|
||||
|
||||
if result.stdout:
|
||||
yield StreamOutputEvent(stream="stdout", data=result.stdout)
|
||||
if result.stderr:
|
||||
yield StreamOutputEvent(stream="stderr", data=result.stderr)
|
||||
yield StreamResultEvent(
|
||||
exit_code=result.exit_code,
|
||||
timed_out=result.timed_out,
|
||||
duration_ms=result.duration_ms,
|
||||
files=result.files,
|
||||
)
|
||||
|
||||
def upload_file(self, file_content: bytes, filename: str) -> str:
|
||||
"""Upload file to Code Interpreter and return file_id"""
|
||||
url = f"{self.base_url}/v1/files"
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -29,15 +28,6 @@ from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamErrorEvent,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamOutputEvent,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
StreamResultEvent,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -104,10 +94,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
is_available = bool(CODE_INTERPRETER_BASE_URL)
|
||||
return is_available
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -193,50 +181,19 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
|
||||
for event in client.execute_streaming(
|
||||
# Execute code with timeout
|
||||
response = client.execute(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=event.data if event.stream == "stdout" else "",
|
||||
stderr=event.data if event.stream == "stderr" else "",
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
)
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
# Handle generated files
|
||||
@@ -245,7 +202,7 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
for workspace_file in response.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
@@ -301,23 +258,26 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
)
|
||||
# Emit delta with stdout/stderr and generated files
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
file_ids=generated_file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
exit_code=response.exit_code,
|
||||
timed_out=response.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if result_event.exit_code == 0 else truncated_stderr,
|
||||
error=None if response.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
# Serialize result for LLM
|
||||
|
||||
@@ -247,6 +247,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,6 +263,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -456,6 +459,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
|
||||
@@ -6,8 +6,6 @@ aioboto3==15.1.0
|
||||
# via onyx
|
||||
aiobotocore==2.24.0
|
||||
# via aioboto3
|
||||
aiofile==3.9.0
|
||||
# via py-key-value-aio
|
||||
aiofiles==25.1.0
|
||||
# via
|
||||
# aioboto3
|
||||
@@ -42,10 +40,8 @@ anyio==4.11.0
|
||||
# httpx
|
||||
# mcp
|
||||
# openai
|
||||
# py-key-value-aio
|
||||
# sse-starlette
|
||||
# starlette
|
||||
# watchfiles
|
||||
argon2-cffi==23.1.0
|
||||
# via pwdlib
|
||||
argon2-cffi-bindings==25.1.0
|
||||
@@ -78,7 +74,9 @@ backports-tarfile==1.2.0 ; python_full_version < '3.12'
|
||||
bcrypt==4.3.0
|
||||
# via pwdlib
|
||||
beartype==0.22.6
|
||||
# via py-key-value-aio
|
||||
# via
|
||||
# py-key-value-aio
|
||||
# py-key-value-shared
|
||||
beautifulsoup4==4.12.3
|
||||
# via
|
||||
# atlassian-python-api
|
||||
@@ -112,8 +110,6 @@ cachetools==6.2.2
|
||||
# via
|
||||
# google-auth
|
||||
# py-key-value-aio
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
@@ -174,6 +170,7 @@ cloudpickle==3.1.2
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# pydocket
|
||||
cobble==0.1.4
|
||||
# via mammoth
|
||||
cohere==5.6.1
|
||||
@@ -221,6 +218,8 @@ deprecated==1.3.1
|
||||
# pygithub
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
diskcache==5.6.3
|
||||
# via py-key-value-aio
|
||||
distributed==2026.1.1
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
@@ -257,6 +256,8 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fakeredis==2.33.0
|
||||
# via pydocket
|
||||
fastapi==0.128.0
|
||||
# via
|
||||
# fastapi-limiter
|
||||
@@ -272,7 +273,7 @@ fastapi-users-db-sqlalchemy==7.0.0
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastmcp==3.0.2
|
||||
fastmcp==2.14.2
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
@@ -477,9 +478,7 @@ jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
# via
|
||||
# fastmcp
|
||||
# onyx
|
||||
# via onyx
|
||||
jsonschema==4.25.1
|
||||
# via
|
||||
# litellm
|
||||
@@ -514,6 +513,8 @@ locket==1.0.0
|
||||
# via
|
||||
# distributed
|
||||
# partd
|
||||
lupa==2.6
|
||||
# via fakeredis
|
||||
lxml==5.3.0
|
||||
# via
|
||||
# htmldate
|
||||
@@ -555,7 +556,7 @@ marshmallow==3.26.2
|
||||
# via dataclasses-json
|
||||
matrix-client==0.3.2
|
||||
# via zulip
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via
|
||||
# claude-agent-sdk
|
||||
# fastmcp
|
||||
@@ -612,7 +613,7 @@ oauthlib==3.2.2
|
||||
# kubernetes
|
||||
# onyx
|
||||
# requests-oauthlib
|
||||
office365-rest-python-client==2.6.2
|
||||
office365-rest-python-client==2.5.9
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via
|
||||
@@ -641,16 +642,22 @@ opensearch-py==3.0.0
|
||||
opentelemetry-api==1.39.1
|
||||
# via
|
||||
# ddtrace
|
||||
# fastmcp
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-exporter-prometheus
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# pydocket
|
||||
opentelemetry-exporter-otlp-proto-common==1.39.1
|
||||
# via opentelemetry-exporter-otlp-proto-http
|
||||
opentelemetry-exporter-otlp-proto-http==1.39.1
|
||||
# via langfuse
|
||||
opentelemetry-exporter-prometheus==0.60b1
|
||||
# via pydocket
|
||||
opentelemetry-instrumentation==0.60b1
|
||||
# via pydocket
|
||||
opentelemetry-proto==1.39.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -661,15 +668,17 @@ opentelemetry-sdk==1.39.1
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-exporter-prometheus
|
||||
opentelemetry-semantic-conventions==0.60b1
|
||||
# via opentelemetry-sdk
|
||||
# via
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
|
||||
# via langsmith
|
||||
packaging==24.2
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
@@ -680,6 +689,7 @@ packaging==24.2
|
||||
# langsmith
|
||||
# marshmallow
|
||||
# onnxruntime
|
||||
# opentelemetry-instrumentation
|
||||
# pytest
|
||||
# pywikibot
|
||||
pandas==2.3.3
|
||||
@@ -692,6 +702,8 @@ passlib==1.7.4
|
||||
# via onyx
|
||||
pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pathvalidate==3.3.1
|
||||
# via py-key-value-aio
|
||||
pdfminer-six==20251107
|
||||
# via markitdown
|
||||
pillow==12.1.1
|
||||
@@ -711,7 +723,9 @@ ply==3.11
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
# opentelemetry-exporter-prometheus
|
||||
# prometheus-fastapi-instrumentator
|
||||
# pydocket
|
||||
prometheus-fastapi-instrumentator==7.1.0
|
||||
# via onyx
|
||||
prompt-toolkit==3.0.52
|
||||
@@ -750,8 +764,12 @@ pwdlib==0.3.0
|
||||
# via fastapi-users
|
||||
py==1.11.0
|
||||
# via retry
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
py-key-value-aio==0.3.0
|
||||
# via
|
||||
# fastmcp
|
||||
# pydocket
|
||||
py-key-value-shared==0.3.0
|
||||
# via py-key-value-aio
|
||||
pyairtable==3.0.1
|
||||
# via onyx
|
||||
pyasn1==0.6.2
|
||||
@@ -788,6 +806,8 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pydocket==0.16.3
|
||||
# via fastmcp
|
||||
pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
@@ -859,6 +879,8 @@ python-http-client==3.3.7
|
||||
# via sendgrid
|
||||
python-iso639==2025.11.16
|
||||
# via unstructured
|
||||
python-json-logger==4.0.0
|
||||
# via pydocket
|
||||
python-magic==0.4.27
|
||||
# via unstructured
|
||||
python-multipart==0.0.22
|
||||
@@ -896,7 +918,6 @@ pyyaml==6.0.3
|
||||
# via
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# huggingface-hub
|
||||
# jsonschema-path
|
||||
# kubernetes
|
||||
@@ -907,8 +928,11 @@ rapidfuzz==3.13.0
|
||||
# unstructured
|
||||
redis==5.0.8
|
||||
# via
|
||||
# fakeredis
|
||||
# fastapi-limiter
|
||||
# onyx
|
||||
# py-key-value-aio
|
||||
# pydocket
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -983,6 +1007,7 @@ rich==14.2.0
|
||||
# via
|
||||
# cyclopts
|
||||
# fastmcp
|
||||
# pydocket
|
||||
# rich-rst
|
||||
# typer
|
||||
rich-rst==1.3.2
|
||||
@@ -1031,7 +1056,9 @@ sniffio==1.3.1
|
||||
# anyio
|
||||
# openai
|
||||
sortedcontainers==2.4.0
|
||||
# via distributed
|
||||
# via
|
||||
# distributed
|
||||
# fakeredis
|
||||
soupsieve==2.8
|
||||
# via beautifulsoup4
|
||||
sqlalchemy==2.0.15
|
||||
@@ -1097,7 +1124,9 @@ tqdm==4.67.1
|
||||
trafilatura==1.12.2
|
||||
# via onyx
|
||||
typer==0.20.0
|
||||
# via mcp
|
||||
# via
|
||||
# mcp
|
||||
# pydocket
|
||||
types-awscrt==0.28.4
|
||||
# via botocore-stubs
|
||||
types-openpyxl==3.0.4.7
|
||||
@@ -1133,10 +1162,11 @@ typing-extensions==4.15.0
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# py-key-value-aio
|
||||
# py-key-value-shared
|
||||
# pyairtable
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydocket
|
||||
# pyee
|
||||
# pygithub
|
||||
# python-docx
|
||||
@@ -1204,8 +1234,6 @@ vine==5.1.0
|
||||
# kombu
|
||||
voyageai==0.2.3
|
||||
# via onyx
|
||||
watchfiles==1.1.1
|
||||
# via fastmcp
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
webencodings==0.5.1
|
||||
@@ -1226,6 +1254,7 @@ wrapt==1.17.3
|
||||
# deprecated
|
||||
# langfuse
|
||||
# openinference-instrumentation
|
||||
# opentelemetry-instrumentation
|
||||
# unstructured
|
||||
xlrd==2.0.2
|
||||
# via markitdown
|
||||
|
||||
@@ -288,7 +288,7 @@ matplotlib-inline==0.2.1
|
||||
# via
|
||||
# ipykernel
|
||||
# ipython
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
multidict==6.7.0
|
||||
# via
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.1
|
||||
onyx-devtools==0.6.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -211,7 +211,7 @@ litellm==1.81.6
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
monotonic==1.6
|
||||
# via posthog
|
||||
|
||||
@@ -246,7 +246,7 @@ litellm==1.81.6
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
mcp==1.25.0
|
||||
# via claude-agent-sdk
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
|
||||
@@ -95,6 +95,7 @@ def generate_dummy_chunk(
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
user_project=[],
|
||||
personas=[],
|
||||
access=DocumentAccess.build(
|
||||
user_emails=user_emails,
|
||||
user_groups=user_groups,
|
||||
|
||||
@@ -3,8 +3,8 @@ set -e
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -55,10 +55,6 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Start the Code Interpreter container
|
||||
echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@@ -47,15 +46,11 @@ def mock_current_admin_user() -> MagicMock:
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan.
|
||||
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
|
||||
# CollectorRegistry" errors when multiple tests each create a new app
|
||||
# (prometheus registers metrics globally and rejects duplicate names).
|
||||
# Initialize TestClient with the FastAPI app using a no-op test lifespan
|
||||
get_app = fetch_versioned_implementation(
|
||||
module="onyx.main", attribute="get_application"
|
||||
)
|
||||
with patch("onyx.main.setup_prometheus_metrics"):
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
|
||||
# Override the database session dependency with a mock
|
||||
# (these tests don't actually need DB access)
|
||||
|
||||
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
External dependency unit tests for persona file sync.
|
||||
|
||||
Validates that:
|
||||
|
||||
1. The check_for_user_file_project_sync beat task picks up UserFiles with
|
||||
needs_persona_sync=True (not just needs_project_sync).
|
||||
|
||||
2. The process_single_user_file_project_sync worker task reads persona
|
||||
associations from the DB, passes persona_ids to the document index via
|
||||
VespaDocumentUserFields, and clears needs_persona_sync afterwards.
|
||||
|
||||
3. upsert_persona correctly marks affected UserFiles with
|
||||
needs_persona_sync=True when file associations change.
|
||||
|
||||
Uses real Redis and PostgreSQL. Document index (Vespa) calls are mocked
|
||||
since we only need to verify the arguments passed to update_single.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_completed_user_file(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
needs_persona_sync: bool = False,
|
||||
needs_project_sync: bool = False,
|
||||
) -> UserFile:
|
||||
"""Insert a UserFile in COMPLETED status."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
needs_project_sync=needs_project_sync,
|
||||
chunk_count=5,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_test_persona(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user_files: list[UserFile] | None = None,
|
||||
) -> Persona:
|
||||
"""Create a minimal Persona via direct model insert."""
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_files=user_files or [],
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _link_file_to_persona(
|
||||
db_session: Session, persona: Persona, user_file: UserFile
|
||||
) -> None:
|
||||
"""Create the join table row between a persona and a user file."""
|
||||
link = Persona__UserFile(persona_id=persona.id, user_file_id=user_file.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on a bound Celery task."""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: check_for_user_file_project_sync picks up persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSweepIncludesPersonaSync:
|
||||
"""The beat task must pick up files needing persona sync, not just project sync."""
|
||||
|
||||
def test_persona_sync_flag_enqueues_task(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with needs_persona_sync=True (and COMPLETED) gets enqueued."""
|
||||
user = create_test_user(db_session, "persona_sweep")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) in enqueued_ids
|
||||
|
||||
def test_neither_flag_does_not_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with both flags False is not enqueued."""
|
||||
user = create_test_user(db_session, "no_sync")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
enqueued_ids = {
|
||||
call.kwargs["kwargs"]["user_file_id"]
|
||||
for call in mock_app.send_task.call_args_list
|
||||
}
|
||||
assert str(uf.id) not in enqueued_ids
|
||||
|
||||
def test_both_flags_enqueues_once(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file with BOTH flags True is enqueued exactly once."""
|
||||
user = create_test_user(db_session, "both_flags")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with _patch_task_app(check_for_user_file_project_sync, mock_app):
|
||||
check_for_user_file_project_sync.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
matching_calls = [
|
||||
call
|
||||
for call in mock_app.send_task.call_args_list
|
||||
if call.kwargs["kwargs"]["user_file_id"] == str(uf.id)
|
||||
]
|
||||
assert len(matching_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: process_single_user_file_project_sync passes persona_ids to index
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_GET_SETTINGS = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_active_search_settings"
|
||||
)
|
||||
_PATCH_GET_INDICES = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.get_all_document_indices"
|
||||
)
|
||||
_PATCH_HTTPX_INIT = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.httpx_init_vespa_pool"
|
||||
)
|
||||
_PATCH_DISABLE_VDB = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.DISABLE_VECTOR_DB"
|
||||
)
|
||||
|
||||
|
||||
class TestSyncTaskWritesPersonaIds:
|
||||
"""The sync task reads persona associations and sends them to the index."""
|
||||
|
||||
def test_passes_persona_ids_to_update_single(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After linking a file to a persona, sync sends the persona ID."""
|
||||
user = create_test_user(db_session, "sync_persona")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = _user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
mock_doc_index.update_single.assert_called_once()
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert persona.id in user_fields.personas
|
||||
assert call_kwargs["doc_id"] == str(uf.id)
|
||||
|
||||
def test_clears_persona_sync_flag(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a successful sync the needs_persona_sync flag is cleared."""
|
||||
user = create_test_user(db_session, "sync_clear")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = _user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with patch(_PATCH_DISABLE_VDB, True):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
def test_passes_both_project_and_persona_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to both a project and a persona gets both IDs."""
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserProject
|
||||
|
||||
user = create_test_user(db_session, "sync_both")
|
||||
uf = _create_completed_user_file(
|
||||
db_session, user, needs_persona_sync=True, needs_project_sync=True
|
||||
)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
project = UserProject(user_id=user.id, name="test-project", instructions="")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
|
||||
link = Project__UserFile(project_id=project.id, user_file_id=uf.id)
|
||||
db_session.add(link)
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = _user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert persona.id in user_fields.personas
|
||||
assert project.id in user_fields.user_projects
|
||||
|
||||
# Both flags should be cleared
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.needs_project_sync is False
|
||||
|
||||
def test_deleted_persona_excluded_from_ids(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A soft-deleted persona should NOT appear in the persona_ids sent to Vespa."""
|
||||
user = create_test_user(db_session, "sync_deleted")
|
||||
uf = _create_completed_user_file(db_session, user, needs_persona_sync=True)
|
||||
persona = _create_test_persona(db_session, user)
|
||||
_link_file_to_persona(db_session, persona, uf)
|
||||
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
mock_doc_index = MagicMock()
|
||||
mock_search_settings = MagicMock()
|
||||
mock_search_settings.primary = MagicMock()
|
||||
mock_search_settings.secondary = None
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
lock_key = _user_file_project_sync_lock_key(str(uf.id))
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
with (
|
||||
patch(_PATCH_DISABLE_VDB, False),
|
||||
patch(_PATCH_HTTPX_INIT),
|
||||
patch(_PATCH_GET_SETTINGS, return_value=mock_search_settings),
|
||||
patch(_PATCH_GET_INDICES, return_value=[mock_doc_index]),
|
||||
):
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=str(uf.id), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
call_kwargs = mock_doc_index.update_single.call_args.kwargs
|
||||
user_fields: VespaDocumentUserFields = call_kwargs["user_fields"]
|
||||
assert persona.id not in user_fields.personas
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: upsert_persona marks files for persona sync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpsertPersonaMarksSyncFlag:
|
||||
"""upsert_persona must set needs_persona_sync on affected UserFiles."""
|
||||
|
||||
def test_creating_persona_with_files_marks_sync(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "upsert_create")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
|
||||
def test_updating_persona_files_marks_both_old_and_new(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When file associations change, both the removed and added files are flagged."""
|
||||
user = create_test_user(db_session, "upsert_update")
|
||||
uf_old = _create_completed_user_file(db_session, user)
|
||||
uf_new = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf_old.id],
|
||||
)
|
||||
|
||||
# Clear the flag from creation so we can observe the update
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[uf_new.id],
|
||||
)
|
||||
|
||||
db_session.refresh(uf_old)
|
||||
db_session.refresh(uf_new)
|
||||
assert uf_old.needs_persona_sync is True, "Removed file should be flagged"
|
||||
assert uf_new.needs_persona_sync is True, "Added file should be flagged"
|
||||
|
||||
def test_removing_all_files_marks_old_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Removing all files from a persona flags the previously associated files."""
|
||||
user = create_test_user(db_session, "upsert_remove")
|
||||
uf = _create_completed_user_file(db_session, user)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
datetime_aware=None,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
user_file_ids=[uf.id],
|
||||
)
|
||||
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=None,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
persona_id=persona.id,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
db_session.refresh(uf)
|
||||
assert uf.needs_persona_sync is True
|
||||
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
Uses real PostgreSQL for UserFile/Persona/UserProject rows.
|
||||
Mocks the LLM tokenizer and file store since they are not relevant here.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_user_file(db_session: Session, user: User) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user.id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.COMPLETED,
|
||||
chunk_count=1,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
document_sets=[],
|
||||
users=[user],
|
||||
groups=[],
|
||||
is_visible=True,
|
||||
is_public=True,
|
||||
display_priority=None,
|
||||
starter_messages=None,
|
||||
deleted=False,
|
||||
user_id=user.id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.commit()
|
||||
db_session.refresh(persona)
|
||||
return persona
|
||||
|
||||
|
||||
def _create_project(db_session: Session, user: User) -> UserProject:
|
||||
project = UserProject(
|
||||
user_id=user.id,
|
||||
name=f"project-{uuid4().hex[:8]}",
|
||||
instructions="",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
"""Build a minimal IndexChunk whose source document ID matches the UserFile."""
|
||||
doc = Document(
|
||||
id=str(user_file.id),
|
||||
source=DocumentSource.USER_FILE,
|
||||
semantic_identifier=user_file.name,
|
||||
sections=[Section(text="test chunk content", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
source_document=doc,
|
||||
chunk_id=0,
|
||||
blurb="test chunk",
|
||||
content="test chunk content",
|
||||
source_links={0: ""},
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.0] * 768,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_persona_gets_persona_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_persona")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_project_gets_project_id(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_project")
|
||||
uf = _create_user_file(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_linked_to_both_gets_both_ids(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_both")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona = _create_persona(db_session, user)
|
||||
project = _create_project(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona.id, user_file_id=uf.id))
|
||||
db_session.add(Project__UserFile(project_id=project.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_file_with_no_associations_gets_empty_lists(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "adapter_empty")
|
||||
uf = _create_user_file(db_session, user)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in test"),
|
||||
)
|
||||
def test_multiple_personas_all_appear(
|
||||
self,
|
||||
_mock_llm: MagicMock,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file linked to multiple personas should have all their IDs."""
|
||||
user = create_test_user(db_session, "adapter_multi")
|
||||
uf = _create_user_file(db_session, user)
|
||||
persona_a = _create_persona(db_session, user)
|
||||
persona_b = _create_persona(db_session, user)
|
||||
|
||||
db_session.add(Persona__UserFile(persona_id=persona_a.id, user_file_id=uf.id))
|
||||
db_session.add(Persona__UserFile(persona_id=persona_b.id, user_file_id=uf.id))
|
||||
db_session.commit()
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=TEST_TENANT_ID, db_session=db_session
|
||||
)
|
||||
chunk = _make_index_chunk(uf)
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
context=context,
|
||||
)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
@@ -144,7 +144,8 @@ def use_mock_search_pipeline(
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
|
||||
@@ -990,27 +990,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
self._respond_json(
|
||||
200, {"file_id": f"mock-ci-file-{self.server._file_counter}"}
|
||||
)
|
||||
elif self.path == "/v1/execute/stream":
|
||||
if self.server.streaming_enabled:
|
||||
self._respond_sse(
|
||||
[
|
||||
(
|
||||
"output",
|
||||
{"stream": "stdout", "data": "mock output\n"},
|
||||
),
|
||||
(
|
||||
"result",
|
||||
{
|
||||
"exit_code": 0,
|
||||
"timed_out": False,
|
||||
"duration_ms": 50,
|
||||
"files": [],
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
self._respond_json(404, {"error": "not found"})
|
||||
elif self.path == "/v1/execute":
|
||||
self._respond_json(
|
||||
200,
|
||||
@@ -1048,17 +1027,6 @@ class _MockCIHandler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
self.wfile.write(payload)
|
||||
|
||||
def _respond_sse(self, events: list[tuple[str, dict[str, Any]]]) -> None:
|
||||
frames = []
|
||||
for event_type, data in events:
|
||||
frames.append(f"event: {event_type}\ndata: {json.dumps(data)}\n\n")
|
||||
payload = "".join(frames).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Content-Length", str(len(payload)))
|
||||
self.end_headers()
|
||||
self.wfile.write(payload)
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
|
||||
pass
|
||||
|
||||
@@ -1070,7 +1038,6 @@ class MockCodeInterpreterServer(HTTPServer):
|
||||
super().__init__(("localhost", 0), _MockCIHandler)
|
||||
self.captured_requests: list[CapturedRequest] = []
|
||||
self._file_counter = 0
|
||||
self.streaming_enabled: bool = True
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@@ -1201,19 +1168,17 @@ def test_code_interpreter_receives_chat_files(
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
|
||||
# Verify: file uploaded, code executed via streaming, staged file cleaned up
|
||||
# Verify: file uploaded, code executed, staged file cleaned up
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/files")) == 1
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
delete_requests = mock_ci_server.get_requests(method="DELETE")
|
||||
assert len(delete_requests) == 1
|
||||
assert delete_requests[0].path.startswith("/v1/files/")
|
||||
|
||||
execute_body = mock_ci_server.get_requests(
|
||||
method="POST", path="/v1/execute/stream"
|
||||
)[0].json_body()
|
||||
execute_body = mock_ci_server.get_requests(method="POST", path="/v1/execute")[
|
||||
0
|
||||
].json_body()
|
||||
assert execute_body["code"] == code
|
||||
assert len(execute_body["files"]) == 1
|
||||
assert execute_body["files"][0]["path"] == "data.csv"
|
||||
@@ -1319,9 +1284,7 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
# The response contains `packets` — a list of packet-lists, one per
|
||||
# assistant message. We should have exactly one assistant message.
|
||||
@@ -1350,76 +1313,3 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
delta_obj = delta_packets[0].obj
|
||||
assert isinstance(delta_obj, PythonToolDelta)
|
||||
assert "mock output" in delta_obj.stdout
|
||||
|
||||
|
||||
def test_code_interpreter_streaming_fallback_to_batch(
|
||||
db_session: Session,
|
||||
mock_ci_server: MockCodeInterpreterServer,
|
||||
_attach_python_tool_to_default_persona: None,
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the streaming endpoint is not available (older code-interpreter),
|
||||
execute_streaming should fall back to the batch /v1/execute endpoint."""
|
||||
mock_ci_server.captured_requests.clear()
|
||||
mock_ci_server._file_counter = 0
|
||||
mock_ci_server.streaming_enabled = False
|
||||
mock_url = mock_ci_server.url
|
||||
|
||||
user = create_test_user(db_session, "ci_fallback_test")
|
||||
chat_session = create_chat_session(db_session=db_session, user=user)
|
||||
|
||||
code = 'print("fallback test")'
|
||||
msg_req = SendMessageRequest(
|
||||
message="Print fallback test",
|
||||
chat_session_id=chat_session.id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
|
||||
with (
|
||||
use_mock_llm() as mock_llm,
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
):
|
||||
mock_llm.add_response(
|
||||
LLMToolCallResponse(
|
||||
tool_name="python",
|
||||
tool_call_id="call_fallback",
|
||||
tool_call_argument_tokens=[json.dumps({"code": code})],
|
||||
)
|
||||
)
|
||||
mock_llm.forward_till_end()
|
||||
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
|
||||
try:
|
||||
packets = list(
|
||||
handle_stream_message_objects(
|
||||
new_msg_req=msg_req, user=user, db_session=db_session
|
||||
)
|
||||
)
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
mock_ci_server.streaming_enabled = True
|
||||
|
||||
# Streaming was attempted first (returned 404), then fell back to batch
|
||||
assert (
|
||||
len(mock_ci_server.get_requests(method="POST", path="/v1/execute/stream")) == 1
|
||||
)
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
# Verify output still made it through
|
||||
delta_packets = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, PythonToolDelta)
|
||||
]
|
||||
assert len(delta_packets) >= 1
|
||||
first_delta = delta_packets[0].obj
|
||||
assert isinstance(first_delta, PythonToolDelta)
|
||||
assert "mock output" in first_delta.stdout
|
||||
|
||||
@@ -5,17 +5,22 @@ from fastapi import FastAPI
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth import StaticTokenVerifier
|
||||
from fastmcp.server.server import FunctionTool
|
||||
|
||||
|
||||
def make_many_tools(mcp: FastMCP) -> None:
|
||||
def make_tool(i: int) -> None:
|
||||
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
|
||||
def make_tool(i: int) -> FunctionTool:
|
||||
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
|
||||
def tool_name(name: str) -> str: # noqa: ARG001
|
||||
"""Get secret value."""
|
||||
return f"Secret value {200 - i}!"
|
||||
|
||||
return tool_name
|
||||
|
||||
tools = []
|
||||
for i in range(100):
|
||||
make_tool(i)
|
||||
tools.append(make_tool(i))
|
||||
return tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -28,6 +28,7 @@ from fastmcp import FastMCP
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from fastmcp.server.auth import TokenVerifier
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from fastmcp.server.server import FunctionTool
|
||||
|
||||
# Google's tokeninfo endpoint for validating access tokens
|
||||
GOOGLE_TOKENINFO_URL = "https://oauth2.googleapis.com/tokeninfo"
|
||||
@@ -147,19 +148,24 @@ class GoogleOAuthTokenVerifier(TokenVerifier):
|
||||
await self._http_client.aclose()
|
||||
|
||||
|
||||
def make_tools(mcp: FastMCP) -> None:
|
||||
def make_tools(mcp: FastMCP) -> list[FunctionTool]:
|
||||
"""Create test tools for the MCP server."""
|
||||
tools: list[FunctionTool] = []
|
||||
|
||||
@mcp.tool(name="echo", description="Echo back the input message")
|
||||
def echo(message: str) -> str:
|
||||
"""Echo the message back to the caller."""
|
||||
return f"You said: {message}"
|
||||
|
||||
tools.append(echo)
|
||||
|
||||
@mcp.tool(name="get_secret", description="Get a secret value (requires auth)")
|
||||
def get_secret(secret_name: str) -> str:
|
||||
"""Get a secret value. This proves the token was validated."""
|
||||
return f"Secret value for '{secret_name}': super-secret-value-12345"
|
||||
|
||||
tools.append(get_secret)
|
||||
|
||||
@mcp.tool(name="whoami", description="Get information about the authenticated user")
|
||||
async def whoami() -> dict[str, Any]:
|
||||
"""Get information about the authenticated user from their Google token."""
|
||||
@@ -176,6 +182,9 @@ def make_tools(mcp: FastMCP) -> None:
|
||||
"access_type": tok.claims.get("access_type"),
|
||||
}
|
||||
|
||||
tools.append(whoami)
|
||||
|
||||
# Add some numbered tools for testing tool discovery
|
||||
for i in range(5):
|
||||
|
||||
@mcp.tool(name=f"oauth_tool_{i}", description=f"Test tool number {i}")
|
||||
@@ -183,6 +192,10 @@ def make_tools(mcp: FastMCP) -> None:
|
||||
"""A numbered test tool."""
|
||||
return f"Tool {_i} says hello to {name}!"
|
||||
|
||||
tools.append(numbered_tool)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1] if len(sys.argv) > 1 else "8006")
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.server import FunctionTool
|
||||
|
||||
mcp = FastMCP("My HTTP MCP")
|
||||
|
||||
@@ -12,15 +13,19 @@ def hello(name: str) -> str:
|
||||
return f"Hello, {name}!"
|
||||
|
||||
|
||||
def make_many_tools() -> None:
|
||||
def make_tool(i: int) -> None:
|
||||
def make_many_tools() -> list[FunctionTool]:
|
||||
def make_tool(i: int) -> FunctionTool:
|
||||
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
|
||||
def tool_name(name: str) -> str: # noqa: ARG001
|
||||
"""Get secret value."""
|
||||
return f"Secret value {100 - i}!"
|
||||
|
||||
return tool_name
|
||||
|
||||
tools = []
|
||||
for i in range(100):
|
||||
make_tool(i)
|
||||
tools.append(make_tool(i))
|
||||
return tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,6 +15,7 @@ from fastapi.responses import Response
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from fastmcp.server.server import FunctionTool
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
# uncomment for debug logs
|
||||
@@ -36,15 +37,18 @@ Enable authorization code and store the client id and secret.
|
||||
"""
|
||||
|
||||
|
||||
def make_many_tools(mcp: FastMCP) -> None:
|
||||
def make_tool(i: int) -> None:
|
||||
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
|
||||
def make_tool(i: int) -> FunctionTool:
|
||||
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
|
||||
def tool_name(name: str) -> str: # noqa: ARG001
|
||||
"""Get secret value."""
|
||||
return f"Secret value {500 - i}!"
|
||||
|
||||
return tool_name
|
||||
|
||||
tools = []
|
||||
for i in range(100):
|
||||
make_tool(i)
|
||||
tools.append(make_tool(i))
|
||||
|
||||
@mcp.tool
|
||||
async def whoami() -> dict[str, Any]:
|
||||
@@ -55,6 +59,9 @@ def make_many_tools(mcp: FastMCP) -> None:
|
||||
"claims": tok.claims if tok else {},
|
||||
}
|
||||
|
||||
tools.append(whoami)
|
||||
return tools
|
||||
|
||||
|
||||
# ---------- FASTAPI APP ----------
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastmcp import FastMCP
|
||||
from fastmcp.server.auth.auth import AccessToken
|
||||
from fastmcp.server.auth.auth import TokenVerifier
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from fastmcp.server.server import FunctionTool
|
||||
|
||||
# pip install fastmcp bcrypt
|
||||
|
||||
@@ -92,15 +93,19 @@ class ApiKeyVerifier(TokenVerifier):
|
||||
# ---- server -----------------------------------------------------------------
|
||||
|
||||
|
||||
def make_many_tools(mcp: FastMCP) -> None:
|
||||
def make_tool(i: int) -> None:
|
||||
def make_many_tools(mcp: FastMCP) -> list[FunctionTool]:
|
||||
def make_tool(i: int) -> FunctionTool:
|
||||
@mcp.tool(name=f"tool_{i}", description=f"Get secret value {i}")
|
||||
def tool_name(name: str) -> str: # noqa: ARG001
|
||||
"""Get secret value."""
|
||||
return f"Secret value {400 - i}!"
|
||||
|
||||
return tool_name
|
||||
|
||||
tools = []
|
||||
for i in range(100):
|
||||
make_tool(i)
|
||||
tools.append(make_tool(i))
|
||||
return tools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -106,13 +106,13 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin user for tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_admin1_{unique}@example.com",
|
||||
email=f"discord_admin1+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create admin user for tenant 2
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_admin2_{unique}@example.com",
|
||||
email=f"discord_admin2+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
@@ -170,10 +170,10 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_list1_{unique}@example.com",
|
||||
email=f"discord_list1+{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_list2_{unique}@example.com",
|
||||
email=f"discord_list2+{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create 1 guild in tenant 1
|
||||
@@ -350,10 +350,10 @@ class TestGuildAccessIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_access1_{unique}@example.com",
|
||||
email=f"discord_access1+{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_access2_{unique}@example.com",
|
||||
email=f"discord_access2+{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create a guild in tenant 1
|
||||
|
||||
@@ -21,7 +21,7 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None: # noqa: ARG00
|
||||
|
||||
# Admin user invites the previously registered and non-registered user
|
||||
UserManager.invite_user(invited_user.email, admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}_{unique}@example.com", admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}+{unique}@example.com", admin_user)
|
||||
|
||||
# Verify users are in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
@@ -40,7 +40,7 @@ def test_non_registered_user_gets_basic_role(
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Admin user invites a non-registered user
|
||||
invited_email = f"{INVITED_BASIC_USER}_{unique}@example.com"
|
||||
invited_email = f"{INVITED_BASIC_USER}+{unique}@example.com"
|
||||
UserManager.invite_user(invited_email, admin_user)
|
||||
|
||||
# Non-registered user registers
|
||||
@@ -58,7 +58,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None: # noqa: A
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Create a user to be invited
|
||||
invited_user_email = f"invited_user_{unique}@example.com"
|
||||
invited_user_email = f"invited_user+{unique}@example.com"
|
||||
|
||||
# User registers with the same email as the invitation
|
||||
invited_user: DATestUser = UserManager.create(
|
||||
|
||||
@@ -20,13 +20,13 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]: # noqa: ARG0
|
||||
unique = uuid4().hex
|
||||
# Creating an admin user for Tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"admin_{unique}@example.com",
|
||||
email=f"admin+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"admin2_{unique}@example.com",
|
||||
email=f"admin2+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
"""
|
||||
Integration tests for onyx.db.engine.tenant_utils.get_schemas_needing_migration.
|
||||
|
||||
These tests require a live database and exercise the function directly,
|
||||
independent of the alembic migration runner script.
|
||||
|
||||
Usage:
|
||||
pytest tests/integration/multitenant_tests/test_get_schemas_needing_migration.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
|
||||
_BACKEND_DIR = __file__[: __file__.index("/tests/")]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_head_rev() -> str:
|
||||
result = subprocess.run(
|
||||
["alembic", "heads", "--resolve-dependencies"],
|
||||
cwd=_BACKEND_DIR,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
|
||||
rev = result.stdout.strip().split()[0]
|
||||
assert len(rev) > 0
|
||||
return rev
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_at_head(
|
||||
engine: Engine, current_head_rev: str
|
||||
) -> Generator[str, None, None]:
|
||||
"""Tenant schema with alembic_version already at head — should be excluded."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.execute(
|
||||
text(
|
||||
f'CREATE TABLE "{schema}".alembic_version '
|
||||
f"(version_num VARCHAR(32) NOT NULL)"
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(f'INSERT INTO "{schema}".alembic_version (version_num) VALUES (:rev)'),
|
||||
{"rev": current_head_rev},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
|
||||
"""Tenant schema with no tables — should be included (needs migration)."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_schema_stale_rev(engine: Engine) -> Generator[str, None, None]:
|
||||
"""Tenant schema with a non-head revision — should be included (needs migration)."""
|
||||
schema = f"tenant_test_{uuid.uuid4().hex[:12]}"
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'CREATE SCHEMA "{schema}"'))
|
||||
conn.execute(
|
||||
text(
|
||||
f'CREATE TABLE "{schema}".alembic_version '
|
||||
f"(version_num VARCHAR(32) NOT NULL)"
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
f'INSERT INTO "{schema}".alembic_version (version_num) '
|
||||
f"VALUES ('stalerev000000000000')"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
yield schema
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_classifies_all_cases(
|
||||
current_head_rev: str,
|
||||
tenant_schema_at_head: str,
|
||||
tenant_schema_empty: str,
|
||||
tenant_schema_stale_rev: str,
|
||||
) -> None:
|
||||
"""Correctly classifies all three schema states:
|
||||
- at head → excluded
|
||||
- no table → included (needs migration)
|
||||
- stale rev → included (needs migration)
|
||||
"""
|
||||
all_schemas = [tenant_schema_at_head, tenant_schema_empty, tenant_schema_stale_rev]
|
||||
result = get_schemas_needing_migration(all_schemas, current_head_rev)
|
||||
|
||||
assert tenant_schema_at_head not in result
|
||||
assert tenant_schema_empty in result
|
||||
assert tenant_schema_stale_rev in result
|
||||
|
||||
|
||||
def test_idempotent(
|
||||
current_head_rev: str,
|
||||
tenant_schema_at_head: str,
|
||||
tenant_schema_empty: str,
|
||||
) -> None:
|
||||
"""Calling the function twice returns the same result.
|
||||
|
||||
Verifies that the DROP TABLE IF EXISTS guards correctly clean up temp
|
||||
tables so a second call succeeds even if the first left state behind.
|
||||
"""
|
||||
schemas = [tenant_schema_at_head, tenant_schema_empty]
|
||||
|
||||
first = get_schemas_needing_migration(schemas, current_head_rev)
|
||||
second = get_schemas_needing_migration(schemas, current_head_rev)
|
||||
|
||||
assert first == second
|
||||
|
||||
|
||||
def test_empty_input(current_head_rev: str) -> None:
|
||||
"""An empty input list returns immediately without touching the DB."""
|
||||
assert get_schemas_needing_migration([], current_head_rev) == []
|
||||
@@ -1,130 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
|
||||
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
|
||||
PYTHON_TOOL_NAME = "python"
|
||||
|
||||
|
||||
def test_get_code_interpreter_health_as_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Health endpoint should return a JSON object with a 'healthy' boolean."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_HEALTH_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "healthy" in data
|
||||
assert isinstance(data["healthy"], bool)
|
||||
|
||||
|
||||
def test_get_code_interpreter_status_as_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""GET endpoint should return a JSON object with an 'enabled' boolean."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "enabled" in data
|
||||
assert isinstance(data["enabled"], bool)
|
||||
|
||||
|
||||
def test_update_code_interpreter_disable_and_enable(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""PUT endpoint should update the enabled flag and persist across reads."""
|
||||
# Disable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": False},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify disabled
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["enabled"] is False
|
||||
|
||||
# Re-enable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": True},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify enabled
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["enabled"] is True
|
||||
|
||||
|
||||
def test_code_interpreter_endpoints_require_admin(
|
||||
basic_user: DATestUser,
|
||||
) -> None:
|
||||
"""All code interpreter endpoints should reject non-admin users."""
|
||||
health_response = requests.get(
|
||||
CODE_INTERPRETER_HEALTH_URL,
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert health_response.status_code == 403
|
||||
|
||||
get_response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert get_response.status_code == 403
|
||||
|
||||
put_response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": True},
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert put_response.status_code == 403
|
||||
|
||||
|
||||
def test_python_tool_hidden_from_tool_list_when_disabled(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""When code interpreter is disabled, the Python tool should not appear
|
||||
in the GET /tool response (i.e. the frontend tool list)."""
|
||||
# Disable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": False},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Python tool should not be in the tool list
|
||||
tools = ToolManager.list_tools(user_performing_action=admin_user)
|
||||
tool_names = [t.name for t in tools]
|
||||
assert PYTHON_TOOL_NAME not in tool_names
|
||||
|
||||
# Re-enable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": True},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Python tool should reappear
|
||||
tools = ToolManager.list_tools(user_performing_action=admin_user)
|
||||
tool_names = [t.name for t in tools]
|
||||
assert PYTHON_TOOL_NAME in tool_names
|
||||
@@ -1,322 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from onyx.configs import app_configs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import ToolName
|
||||
|
||||
|
||||
_ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
|
||||
class NightlyProviderConfig(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
provider: str
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
custom_config_json = os.environ.get(_ENV_CUSTOM_CONFIG_JSON, "").strip()
|
||||
if custom_config_json:
|
||||
parsed = json.loads(custom_config_json)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
|
||||
custom_config = {str(key): str(value) for key, value in parsed.items()}
|
||||
|
||||
if provider == "ollama_chat" and api_key and not custom_config:
|
||||
custom_config = {"OLLAMA_API_KEY": api_key}
|
||||
|
||||
return NightlyProviderConfig(
|
||||
provider=provider,
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
|
||||
def _skip_or_fail(strict: bool, message: str) -> None:
|
||||
if strict:
|
||||
pytest.fail(message)
|
||||
pytest.skip(message)
|
||||
|
||||
|
||||
def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
if not config.provider:
|
||||
_skip_or_fail(strict=config.strict, message=f"{_ENV_PROVIDER} must be set")
|
||||
|
||||
if not config.model_names:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
config.api_base or _default_api_base_for_provider(config.provider)
|
||||
):
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
app_configs.INTEGRATION_TESTS_MODE is True
|
||||
), "Integration tests require INTEGRATION_TESTS_MODE=true."
|
||||
|
||||
|
||||
def _seed_connector_for_search_tool(admin_user: DATestUser) -> None:
|
||||
# SearchTool is only exposed when at least one non-default connector exists.
|
||||
CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
def _get_internal_search_tool_id(admin_user: DATestUser) -> int:
|
||||
tools = ToolManager.list_tools(user_performing_action=admin_user)
|
||||
for tool in tools:
|
||||
if tool.in_code_tool_id == SEARCH_TOOL_ID:
|
||||
return tool.id
|
||||
raise AssertionError("SearchTool must exist for this test")
|
||||
|
||||
|
||||
def _default_api_base_for_provider(provider: str) -> str | None:
|
||||
if provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
if provider == "ollama_chat":
|
||||
# host.docker.internal works when tests are running inside the integration test container.
|
||||
return "http://host.docker.internal:11434"
|
||||
return None
|
||||
|
||||
|
||||
def _create_provider_payload(
|
||||
provider: str,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
"name": provider_name,
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"personas": [],
|
||||
"model_configurations": [{"name": model_name, "is_visible": True}],
|
||||
"api_key_changed": bool(api_key),
|
||||
"custom_config_changed": bool(custom_config),
|
||||
}
|
||||
|
||||
|
||||
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
|
||||
list_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
list_response.raise_for_status()
|
||||
providers = list_response.json()
|
||||
|
||||
current_default = next(
|
||||
(provider for provider in providers if provider.get("is_default_provider")),
|
||||
None,
|
||||
)
|
||||
assert (
|
||||
current_default is not None
|
||||
), "Expected a default provider after setting provider as default"
|
||||
assert (
|
||||
current_default["id"] == provider_id
|
||||
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
|
||||
|
||||
|
||||
def _run_chat_assertions(
|
||||
admin_user: DATestUser,
|
||||
search_tool_id: int,
|
||||
provider: str,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
last_error: str | None = None
|
||||
# Retry once to reduce transient nightly flakes due provider-side blips.
|
||||
for attempt in range(1, 3):
|
||||
chat_session = ChatSessionManager.create(user_performing_action=admin_user)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message=(
|
||||
"Use internal_search to search for 'nightly-provider-regression-sentinel', "
|
||||
"then summarize the result in one short sentence."
|
||||
),
|
||||
user_performing_action=admin_user,
|
||||
forced_tool_ids=[search_tool_id],
|
||||
)
|
||||
|
||||
if response.error is None:
|
||||
used_internal_search = any(
|
||||
used_tool.tool_name == ToolName.INTERNAL_SEARCH
|
||||
for used_tool in response.used_tools
|
||||
)
|
||||
debug_has_internal_search = any(
|
||||
debug_tool_call.tool_name == "internal_search"
|
||||
for debug_tool_call in response.tool_call_debug
|
||||
)
|
||||
has_answer = bool(response.full_message.strip())
|
||||
|
||||
if used_internal_search and debug_has_internal_search and has_answer:
|
||||
return
|
||||
|
||||
last_error = (
|
||||
f"attempt={attempt} provider={provider} model={model_name} "
|
||||
f"used_internal_search={used_internal_search} "
|
||||
f"debug_internal_search={debug_has_internal_search} "
|
||||
f"has_answer={has_answer} "
|
||||
f"tool_call_debug={response.tool_call_debug}"
|
||||
)
|
||||
else:
|
||||
last_error = (
|
||||
f"attempt={attempt} provider={provider} model={model_name} "
|
||||
f"stream_error={response.error.error}"
|
||||
)
|
||||
|
||||
time.sleep(attempt)
|
||||
|
||||
pytest.fail(f"Chat/tool-call assertions failed: {last_error}")
|
||||
|
||||
|
||||
def _create_and_test_provider_for_model(
|
||||
admin_user: DATestUser,
|
||||
config: NightlyProviderConfig,
|
||||
model_name: str,
|
||||
search_tool_id: int,
|
||||
) -> None:
|
||||
provider_name = f"nightly-{config.provider}-{uuid4().hex[:12]}"
|
||||
resolved_api_base = config.api_base or _default_api_base_for_provider(
|
||||
config.provider
|
||||
)
|
||||
|
||||
provider_payload = _create_provider_payload(
|
||||
provider=config.provider,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
test_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/test",
|
||||
headers=admin_user.headers,
|
||||
json=provider_payload,
|
||||
)
|
||||
assert test_response.status_code == 200, (
|
||||
f"Provider test endpoint failed for provider={config.provider} "
|
||||
f"model={model_name}: {test_response.status_code} {test_response.text}"
|
||||
)
|
||||
|
||||
create_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json=provider_payload,
|
||||
)
|
||||
assert create_response.status_code == 200, (
|
||||
f"Provider creation failed for provider={config.provider} "
|
||||
f"model={model_name}: {create_response.status_code} {create_response.text}"
|
||||
)
|
||||
provider_id = create_response.json()["id"]
|
||||
|
||||
try:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert set_default_response.status_code == 200, (
|
||||
f"Setting default provider failed for provider={config.provider} "
|
||||
f"model={model_name}: {set_default_response.status_code} "
|
||||
f"{set_default_response.text}"
|
||||
)
|
||||
|
||||
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
|
||||
_run_chat_assertions(
|
||||
admin_user=admin_user,
|
||||
search_tool_id=search_tool_id,
|
||||
provider=config.provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
finally:
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
|
||||
def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
"""Nightly regression test for provider setup + default selection + chat tool calls."""
|
||||
_assert_integration_mode_enabled()
|
||||
config = _load_provider_config()
|
||||
_validate_provider_config(config)
|
||||
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Integration tests for the unified persona file context flow.
|
||||
|
||||
End-to-end tests that verify:
|
||||
1. Files can be uploaded and attached to a persona via API.
|
||||
2. The persona correctly reports its attached files.
|
||||
3. A chat session with a file-bearing persona processes without error.
|
||||
4. Precedence: custom persona files take priority over project files when
|
||||
the chat session is inside a project.
|
||||
|
||||
These tests run against a real Onyx deployment (all services running).
|
||||
File processing is asynchronous, so we poll the file status endpoint
|
||||
until files reach COMPLETED before chatting.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FILE_PROCESSING_POLL_INTERVAL = 2
|
||||
|
||||
|
||||
def _poll_file_statuses(
|
||||
user_file_ids: list[str],
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus = UserFileStatus.COMPLETED,
|
||||
timeout: int = MAX_DELAY,
|
||||
) -> None:
|
||||
"""Block until all files reach the target status or timeout expires."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/file/statuses",
|
||||
json={"file_ids": user_file_ids},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
statuses = response.json()
|
||||
if all(f["status"] == target_status.value for f in statuses):
|
||||
return
|
||||
time.sleep(FILE_PROCESSING_POLL_INTERVAL)
|
||||
raise TimeoutError(
|
||||
f"Files {user_file_ids} did not reach {target_status.value} "
|
||||
f"within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
def _get_persona_detail(persona_id: int, user: DATestUser) -> dict:
|
||||
"""Fetch the full persona snapshot from the API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _create_chat_session_with_project(
|
||||
persona_id: int, project_id: int, user: DATestUser
|
||||
) -> dict:
|
||||
"""Create a chat session explicitly inside a project."""
|
||||
req = ChatSessionCreationRequest(persona_id=persona_id, project_id=project_id)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
json=req.model_dump(),
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_with_files_chat_no_error(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Upload files, attach them to a persona, wait for processing,
|
||||
then send a chat message. Verify no error is returned."""
|
||||
|
||||
# Upload files (creates UserFile records)
|
||||
text_file = create_test_text_file(
|
||||
"The secret project codename is NIGHTINGALE. "
|
||||
"It was started in 2024 by the Advanced Research division."
|
||||
)
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("nightingale_brief.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
assert len(file_descriptors) == 1
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"]
|
||||
|
||||
# Wait for file processing
|
||||
_poll_file_statuses([user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with the file attached
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Nightingale Agent",
|
||||
description="Agent with secret file",
|
||||
system_prompt="You are a helpful assistant with access to uploaded files.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
# Verify persona has the file
|
||||
detail = _get_persona_detail(persona.id, admin_user)
|
||||
persona_file_ids = [str(f["id"]) for f in detail.get("user_files", [])]
|
||||
assert user_file_id in persona_file_ids
|
||||
|
||||
# Chat with the persona
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test persona file context",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="What is the secret project codename?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
assert len(response.full_message) > 0, "Response should not be empty"
|
||||
|
||||
|
||||
def test_persona_without_files_still_works(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A persona with no attached files should still chat normally."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Blank Agent",
|
||||
description="No files attached",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=persona.id,
|
||||
description="Test blank persona",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="Hello, how are you?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
def test_persona_files_override_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a custom persona (with its own files) is used inside a project,
|
||||
the persona's files take precedence — the project's files are invisible.
|
||||
|
||||
We verify this by putting different content in project vs persona files
|
||||
and checking which content the model responds with."""
|
||||
|
||||
# Upload persona file
|
||||
persona_file = create_test_text_file("The persona's secret word is ALBATROSS.")
|
||||
persona_fds, err1 = FileManager.upload_files(
|
||||
files=[("persona_secret.txt", persona_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not err1
|
||||
persona_user_file_id = persona_fds[0]["user_file_id"]
|
||||
|
||||
# Create a project and upload project files
|
||||
project = ProjectManager.create(
|
||||
name="Precedence Test Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_secret.txt", b"The project's secret word is FLAMINGO."),
|
||||
]
|
||||
ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for persona file processing
|
||||
_poll_file_statuses([persona_user_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create persona with persona file
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="Override Agent",
|
||||
description="Persona with its own files",
|
||||
system_prompt="You are a helpful assistant. Answer using the files.",
|
||||
user_file_ids=[persona_user_file_id],
|
||||
)
|
||||
|
||||
# Create chat session inside the project but using the custom persona
|
||||
session_data = _create_chat_session_with_project(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user=admin_user,
|
||||
)
|
||||
chat_session_id = session_data["chat_session_id"]
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message="What is the secret word?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None, f"Chat should succeed, got error: {response.error}"
|
||||
# The persona's file should be what the model sees, not the project's
|
||||
message_lower = response.full_message.lower()
|
||||
assert "albatross" in message_lower, (
|
||||
"Response should reference the persona file's secret word (ALBATROSS), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_default_persona_in_project_uses_project_files(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When the default persona (id=0) is used inside a project,
|
||||
the project's files should be used for context."""
|
||||
project = ProjectManager.create(
|
||||
name="Default Persona Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
project_files = [
|
||||
("project_info.txt", b"The project mascot is a PANGOLIN."),
|
||||
]
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=project_files,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(upload_result.user_files) == 1
|
||||
|
||||
# Wait for project file processing
|
||||
project_file_id = str(upload_result.user_files[0].id)
|
||||
_poll_file_statuses([project_file_id], admin_user, timeout=120)
|
||||
|
||||
# Create chat session inside project using default persona (id=0)
|
||||
session_data = _create_chat_session_with_project(
|
||||
persona_id=0,
|
||||
project_id=project.id,
|
||||
user=admin_user,
|
||||
)
|
||||
chat_session_id = session_data["chat_session_id"]
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message="What is the project mascot?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
assert "pangolin" in response.full_message.lower(), (
|
||||
"Response should reference the project file content (PANGOLIN), "
|
||||
f"but got: {response.full_message}"
|
||||
)
|
||||
|
||||
|
||||
def test_custom_persona_no_files_in_project_ignores_project(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""A custom persona with NO files, used inside a project with files,
|
||||
should NOT see the project's files. The project is purely organizational.
|
||||
|
||||
We verify by asking about content only in the project file and checking
|
||||
the model does NOT reference it."""
|
||||
|
||||
project = ProjectManager.create(
|
||||
name="Ignored Project",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("project_only.txt", b"The project secret is CAPYBARA.")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Custom persona with no files
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="No Files Agent",
|
||||
description="No files, project is irrelevant",
|
||||
system_prompt=(
|
||||
"You are a helpful assistant. If you do not have information "
|
||||
"to answer a question, say 'I do not have that information.'"
|
||||
),
|
||||
)
|
||||
|
||||
session_data = _create_chat_session_with_project(
|
||||
persona_id=persona.id,
|
||||
project_id=project.id,
|
||||
user=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=session_data["chat_session_id"],
|
||||
message="What is the project secret?",
|
||||
user_performing_action=admin_user,
|
||||
mock_llm_response="I do not have that information.",
|
||||
)
|
||||
|
||||
assert response.error is None
|
||||
# With mock_llm_response the model echoes back the mock.
|
||||
# The key assertion is that the chat completes without error
|
||||
# and the project file content is not injected.
|
||||
assert len(response.full_message) > 0
|
||||
@@ -3,7 +3,6 @@ from fastapi import HTTPException
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_email_domain
|
||||
from onyx.configs.constants import AuthType
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_case_insensitive_match(
|
||||
@@ -36,37 +35,3 @@ def test_verify_email_domain_invalid_email_format(
|
||||
verify_email_domain("userexample.com") # missing '@'
|
||||
assert exc.value.status_code == 400
|
||||
assert "Email is not valid" in exc.value.detail
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_plus_addressing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
verify_email_domain("user+tag@gmail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "'+'" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_plus_for_onyx_app(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
# Should not raise for onyx.app domain
|
||||
verify_email_domain("user+tag@onyx.app")
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_googlemail(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
verify_email_domain("user@googlemail.com")
|
||||
assert exc.value.status_code == 400
|
||||
assert "gmail.com" in str(exc.value.detail)
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_project_sync_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_project_sync,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_project_sync,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
|
||||
|
||||
def _build_redis_mock_with_lock() -> tuple[MagicMock, MagicMock]:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
return redis_client, lock
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks."
|
||||
"get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
@patch("onyx.background.celery.tasks.user_file_processing.tasks.get_redis_client")
|
||||
def test_check_for_user_file_project_sync_applies_queue_backpressure(
|
||||
mock_get_redis_client: MagicMock,
|
||||
mock_get_queue_depth: MagicMock,
|
||||
) -> None:
|
||||
redis_client, lock = _build_redis_mock_with_lock()
|
||||
mock_get_redis_client.return_value = redis_client
|
||||
mock_get_queue_depth.return_value = USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH + 1
|
||||
|
||||
task_app = MagicMock()
|
||||
with patch.object(check_for_user_file_project_sync, "app", task_app):
|
||||
check_for_user_file_project_sync.run(tenant_id="test-tenant")
|
||||
|
||||
task_app.send_task.assert_not_called()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks."
|
||||
"enqueue_user_file_project_sync_task"
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks."
|
||||
"get_user_file_project_sync_queue_depth"
|
||||
)
|
||||
@patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks."
|
||||
"get_session_with_current_tenant"
|
||||
)
|
||||
@patch("onyx.background.celery.tasks.user_file_processing.tasks.get_redis_client")
|
||||
def test_check_for_user_file_project_sync_skips_duplicates(
|
||||
mock_get_redis_client: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_queue_depth: MagicMock,
|
||||
mock_enqueue: MagicMock,
|
||||
) -> None:
|
||||
redis_client, lock = _build_redis_mock_with_lock()
|
||||
mock_get_redis_client.return_value = redis_client
|
||||
mock_get_queue_depth.return_value = 0
|
||||
|
||||
user_file_id_one = uuid4()
|
||||
user_file_id_two = uuid4()
|
||||
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.all.return_value = [
|
||||
user_file_id_one,
|
||||
user_file_id_two,
|
||||
]
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
mock_enqueue.side_effect = [True, False]
|
||||
|
||||
task_app = MagicMock()
|
||||
with patch.object(check_for_user_file_project_sync, "app", task_app):
|
||||
check_for_user_file_project_sync.run(tenant_id="test-tenant")
|
||||
|
||||
assert mock_enqueue.call_count == 2
|
||||
lock.release.assert_called_once()
|
||||
|
||||
|
||||
def test_enqueue_user_file_project_sync_task_sets_guard_and_expiry() -> None:
|
||||
redis_client = MagicMock()
|
||||
redis_client.set.return_value = True
|
||||
celery_app = MagicMock()
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
enqueued = enqueue_user_file_project_sync_task(
|
||||
celery_app=celery_app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
|
||||
assert enqueued is True
|
||||
redis_client.set.assert_called_once_with(
|
||||
_user_file_project_sync_queued_key(user_file_id),
|
||||
1,
|
||||
nx=True,
|
||||
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
celery_app.send_task.assert_called_once_with(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file_id, "tenant_id": "test-tenant"},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_user_file_project_sync_task_rolls_back_guard_on_publish_failure() -> (
|
||||
None
|
||||
):
|
||||
redis_client = MagicMock()
|
||||
redis_client.set.return_value = True
|
||||
celery_app = MagicMock()
|
||||
celery_app.send_task.side_effect = RuntimeError("publish failed")
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
with pytest.raises(RuntimeError):
|
||||
enqueue_user_file_project_sync_task(
|
||||
celery_app=celery_app,
|
||||
redis_client=redis_client,
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
redis_client.delete.assert_called_once_with(
|
||||
_user_file_project_sync_queued_key(user_file_id)
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.background.celery.tasks.user_file_processing.tasks.get_redis_client")
|
||||
def test_process_single_user_file_project_sync_clears_queued_guard_on_pickup(
|
||||
mock_get_redis_client: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis_client.return_value = redis_client
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
process_single_user_file_project_sync.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
redis_client.delete.assert_called_once_with(
|
||||
_user_file_project_sync_queued_key(user_file_id)
|
||||
)
|
||||
467
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
467
backend/tests/unit/onyx/chat/test_context_files.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""Tests for the unified context file extraction logic (Phase 5).
|
||||
|
||||
Covers:
|
||||
- _resolve_context_user_files: precedence rule (custom persona supersedes project)
|
||||
- _extract_context_files: all-or-nothing context window fit check
|
||||
- Search filter / search_usage determination in the caller
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.process_message import _extract_context_files
|
||||
from onyx.chat.process_message import _resolve_context_user_files
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
token_count: int = 100,
|
||||
name: str = "file.txt",
|
||||
file_id: str | None = None,
|
||||
) -> MagicMock:
|
||||
uf = MagicMock()
|
||||
uf.id = UUID(file_id) if file_id else uuid4()
|
||||
uf.file_id = str(uf.id)
|
||||
uf.name = name
|
||||
uf.token_count = token_count
|
||||
return uf
|
||||
|
||||
|
||||
def _make_persona(
|
||||
persona_id: int,
|
||||
user_files: list | None = None,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.id = persona_id
|
||||
persona.user_files = user_files or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_in_memory_file(
|
||||
file_id: str,
|
||||
content: str = "hello world",
|
||||
file_type: ChatFileType = ChatFileType.PLAIN_TEXT,
|
||||
filename: str = "file.txt",
|
||||
) -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=content.encode("utf-8"),
|
||||
file_type=file_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _resolve_context_user_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveContextUserFiles:
|
||||
"""Precedence rule: custom persona fully supersedes project."""
|
||||
|
||||
def test_custom_persona_with_files_returns_persona_files(self) -> None:
|
||||
persona_files = [_make_user_file(), _make_user_file()]
|
||||
persona = _make_persona(persona_id=42, user_files=persona_files)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == persona_files
|
||||
|
||||
def test_custom_persona_without_files_returns_empty(self) -> None:
|
||||
"""Custom persona with no files should NOT fall through to project."""
|
||||
persona = _make_persona(persona_id=42, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_custom_persona_none_files_returns_empty(self) -> None:
|
||||
"""Custom persona with user_files=None should NOT fall through."""
|
||||
persona = _make_persona(persona_id=42, user_files=None)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_default_persona_in_project_returns_project_files(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
project_files = [_make_user_file(), _make_user_file()]
|
||||
mock_get_files.return_value = project_files
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
user_id = uuid4()
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
assert result == project_files
|
||||
mock_get_files.assert_called_once_with(
|
||||
project_id=99, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
def test_default_persona_no_project_returns_empty(self) -> None:
|
||||
persona = _make_persona(persona_id=DEFAULT_PERSONA_ID)
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=None, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("onyx.chat.process_message.get_user_files_from_project")
|
||||
def test_custom_persona_without_files_ignores_project(
|
||||
self, mock_get_files: MagicMock
|
||||
) -> None:
|
||||
"""Even with a project_id, custom persona means project is invisible."""
|
||||
persona = _make_persona(persona_id=7, user_files=[])
|
||||
db_session = MagicMock()
|
||||
|
||||
result = _resolve_context_user_files(
|
||||
persona=persona, project_id=99, user_id=uuid4(), db_session=db_session
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_get_files.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _extract_context_files
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtractContextFiles:
|
||||
"""All-or-nothing context window fit check."""
|
||||
|
||||
def test_empty_user_files_returns_empty(self) -> None:
|
||||
db_session = MagicMock()
|
||||
result = _extract_context_files(
|
||||
user_files=[],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.uncapped_token_count is None
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_files_fit_in_context_are_loaded(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=100, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
_make_in_memory_file(file_id=file_id, content="file content")
|
||||
]
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == ["file content"]
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.total_token_count == 100
|
||||
assert len(result.file_metadata) == 1
|
||||
assert result.file_metadata[0].file_id == file_id
|
||||
|
||||
def test_files_overflow_context_not_loaded(self) -> None:
|
||||
"""When aggregate tokens exceed 60% of available window, nothing is loaded."""
|
||||
uf = _make_user_file(token_count=7000)
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.file_texts == []
|
||||
assert result.image_files == []
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.uncapped_token_count == 7000
|
||||
assert result.total_token_count == 0
|
||||
|
||||
def test_overflow_boundary_exact(self) -> None:
|
||||
"""Token count exactly at the 60% boundary should trigger overflow."""
|
||||
# Available = (10000 - 0) * 0.6 = 6000. Tokens = 6000 → >= threshold.
|
||||
uf = _make_user_file(token_count=6000)
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_just_under_boundary_loads(self, mock_load: MagicMock) -> None:
|
||||
"""Token count just under the 60% boundary should load files."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=5999, file_id=file_id)
|
||||
mock_load.return_value = [_make_in_memory_file(file_id=file_id, content="data")]
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert result.file_texts == ["data"]
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_multiple_files_aggregate_check(self, mock_load: MagicMock) -> None:
|
||||
"""Multiple small files that individually fit but collectively overflow."""
|
||||
files = [_make_user_file(token_count=2500) for _ in range(3)]
|
||||
# 3 * 2500 = 7500 > 6000 threshold
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=files,
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
assert result.file_texts == []
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_reserved_tokens_reduce_available_space(self, mock_load: MagicMock) -> None:
|
||||
"""Reserved tokens shrink the available window."""
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=3000, file_id=file_id)
|
||||
# Available = (10000 - 5000) * 0.6 = 3000. Tokens = 3000 → overflow.
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=5000,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is True
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("onyx.chat.process_message.load_in_memory_chat_files")
|
||||
def test_image_files_are_extracted(self, mock_load: MagicMock) -> None:
|
||||
file_id = str(uuid4())
|
||||
uf = _make_user_file(token_count=50, file_id=file_id)
|
||||
mock_load.return_value = [
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
]
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert len(result.image_files) == 1
|
||||
assert result.image_files[0].file_id == file_id
|
||||
assert result.file_texts == []
|
||||
assert result.total_token_count == 50
|
||||
|
||||
@patch("onyx.chat.process_message.DISABLE_VECTOR_DB", True)
|
||||
def test_overflow_with_vector_db_disabled_provides_tool_metadata(self) -> None:
|
||||
"""When vector DB is disabled, overflow produces FileToolMetadata."""
|
||||
uf = _make_user_file(token_count=7000, name="bigfile.txt")
|
||||
|
||||
result = _extract_context_files(
|
||||
user_files=[uf],
|
||||
llm_max_context_window=10000,
|
||||
reserved_token_count=0,
|
||||
db_session=MagicMock(),
|
||||
)
|
||||
|
||||
assert result.use_as_search_filter is False
|
||||
assert len(result.file_metadata_for_tool) == 1
|
||||
assert result.file_metadata_for_tool[0].filename == "bigfile.txt"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Search filter + search_usage determination
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSearchFilterDetermination:
|
||||
"""Verify that the caller correctly determines search_project_id,
|
||||
search_persona_id, and search_usage based on the extraction result
|
||||
and the precedence rule.
|
||||
|
||||
These test the logic inline in handle_stream_message_objects by
|
||||
exercising the same conditionals in isolation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _determine_search_params(
|
||||
persona_id: int,
|
||||
has_persona_files: bool, # noqa: ARG004
|
||||
project_id: int | None,
|
||||
use_as_search_filter: bool,
|
||||
file_texts: list[str] | None = None,
|
||||
uncapped_token_count: int | None = None,
|
||||
) -> dict:
|
||||
"""Replicate the search filter + search_usage logic from
|
||||
handle_stream_message_objects."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
|
||||
if use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
has_context_files = bool(uncapped_token_count)
|
||||
files_loaded_in_context = bool(file_texts)
|
||||
|
||||
if use_as_search_filter:
|
||||
search_usage = SearchToolUsage.ENABLED
|
||||
elif files_loaded_in_context or not has_context_files:
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return {
|
||||
"search_project_id": search_project_id,
|
||||
"search_persona_id": search_persona_id,
|
||||
"search_usage": search_usage,
|
||||
}
|
||||
|
||||
def test_custom_persona_files_fit_no_filter(self) -> None:
|
||||
"""Custom persona, files fit → no search filter, AUTO."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=42,
|
||||
has_persona_files=True,
|
||||
project_id=99,
|
||||
use_as_search_filter=False,
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
)
|
||||
assert result["search_project_id"] is None
|
||||
assert result["search_persona_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_files_overflow_persona_filter(self) -> None:
|
||||
"""Custom persona, files overflow → persona_id filter, AUTO."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=42,
|
||||
has_persona_files=True,
|
||||
project_id=99,
|
||||
use_as_search_filter=True,
|
||||
)
|
||||
assert result["search_persona_id"] == 42
|
||||
assert result["search_project_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.AUTO
|
||||
|
||||
def test_custom_persona_no_files_no_project_leak(self) -> None:
|
||||
"""Custom persona (no files) in project → nothing leaks from project."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=42,
|
||||
has_persona_files=False,
|
||||
project_id=99,
|
||||
use_as_search_filter=False,
|
||||
)
|
||||
assert result["search_project_id"] is None
|
||||
assert result["search_persona_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_files_fit_disables_search(self) -> None:
|
||||
"""Default persona, project files fit → DISABLED."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
has_persona_files=False,
|
||||
project_id=99,
|
||||
use_as_search_filter=False,
|
||||
file_texts=["content"],
|
||||
uncapped_token_count=100,
|
||||
)
|
||||
assert result["search_project_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.DISABLED
|
||||
|
||||
def test_default_persona_project_files_overflow_enables_search(self) -> None:
|
||||
"""Default persona, project files overflow → ENABLED + project_id filter."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
has_persona_files=False,
|
||||
project_id=99,
|
||||
use_as_search_filter=True,
|
||||
uncapped_token_count=7000,
|
||||
)
|
||||
assert result["search_project_id"] == 99
|
||||
assert result["search_persona_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.ENABLED
|
||||
|
||||
def test_default_persona_no_project_auto(self) -> None:
|
||||
"""Default persona, no project → AUTO."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
has_persona_files=False,
|
||||
project_id=None,
|
||||
use_as_search_filter=False,
|
||||
)
|
||||
assert result["search_project_id"] is None
|
||||
assert result["search_usage"] == SearchToolUsage.AUTO
|
||||
|
||||
def test_default_persona_project_no_files_disables_search(self) -> None:
|
||||
"""Default persona in project with no files → DISABLED."""
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
|
||||
result = self._determine_search_params(
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
has_persona_files=False,
|
||||
project_id=99,
|
||||
use_as_search_filter=False,
|
||||
file_texts=None,
|
||||
uncapped_token_count=None,
|
||||
)
|
||||
assert result["search_usage"] == SearchToolUsage.DISABLED
|
||||
@@ -7,10 +7,10 @@ from onyx.chat.llm_loop import _try_fallback_tool_extraction
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import ContextFileMetadata
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -76,18 +76,18 @@ def create_tool_response(
|
||||
|
||||
def create_project_files(
|
||||
num_files: int = 0, num_images: int = 0, tokens_per_file: int = 100
|
||||
) -> ExtractedProjectFiles:
|
||||
"""Helper to create ExtractedProjectFiles for testing."""
|
||||
project_file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
project_file_metadata = [
|
||||
ProjectFileMetadata(
|
||||
) -> ExtractedContextFiles:
|
||||
"""Helper to create ExtractedContextFiles for testing."""
|
||||
file_texts = [f"Project file {i} content" for i in range(num_files)]
|
||||
file_metadata = [
|
||||
ContextFileMetadata(
|
||||
file_id=f"file_{i}",
|
||||
filename=f"file_{i}.txt",
|
||||
file_content=f"Project file {i} content",
|
||||
)
|
||||
for i in range(num_files)
|
||||
]
|
||||
project_image_files = [
|
||||
image_files = [
|
||||
ChatLoadedFile(
|
||||
file_id=f"image_{i}",
|
||||
content=b"",
|
||||
@@ -98,13 +98,13 @@ def create_project_files(
|
||||
)
|
||||
for i in range(num_images)
|
||||
]
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=project_file_texts,
|
||||
project_image_files=project_image_files,
|
||||
project_as_filter=False,
|
||||
return ExtractedContextFiles(
|
||||
file_texts=file_texts,
|
||||
image_files=image_files,
|
||||
use_as_search_filter=False,
|
||||
total_token_count=num_files * tokens_per_file,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=num_files * tokens_per_file,
|
||||
file_metadata=file_metadata,
|
||||
uncapped_token_count=num_files * tokens_per_file,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentBase
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
|
||||
def _minimal_doc_kwargs(metadata: dict) -> dict:
|
||||
return {
|
||||
"id": "test-doc",
|
||||
"sections": [TextSection(text="hello", link="http://example.com")],
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
"semantic_identifier": "Test Doc",
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
def test_int_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"count": 42}))
|
||||
assert doc.metadata == {"count": "42"}
|
||||
|
||||
|
||||
def test_float_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"score": 3.14}))
|
||||
assert doc.metadata == {"score": "3.14"}
|
||||
|
||||
|
||||
def test_bool_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"active": True}))
|
||||
assert doc.metadata == {"active": "True"}
|
||||
|
||||
|
||||
def test_list_of_ints_coerced_to_list_of_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"ids": [1, 2, 3]}))
|
||||
assert doc.metadata == {"ids": ["1", "2", "3"]}
|
||||
|
||||
|
||||
def test_list_of_mixed_types_coerced_to_list_of_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"tags": ["a", 1, True, 2.5]}))
|
||||
assert doc.metadata == {"tags": ["a", "1", "True", "2.5"]}
|
||||
|
||||
|
||||
def test_list_of_dicts_coerced_to_list_of_str() -> None:
|
||||
raw = {"nested": [{"key": "val"}, {"key2": "val2"}]}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {"nested": ["{'key': 'val'}", "{'key2': 'val2'}"]}
|
||||
|
||||
|
||||
def test_dict_value_coerced_to_str() -> None:
|
||||
raw = {"info": {"inner_key": "inner_val"}}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {"info": "{'inner_key': 'inner_val'}"}
|
||||
|
||||
|
||||
def test_none_value_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"empty": None}))
|
||||
assert doc.metadata == {"empty": "None"}
|
||||
|
||||
|
||||
def test_already_valid_str_values_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"key": "value"}))
|
||||
assert doc.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_already_valid_list_of_str_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"tags": ["a", "b", "c"]}))
|
||||
assert doc.metadata == {"tags": ["a", "b", "c"]}
|
||||
|
||||
|
||||
def test_empty_metadata_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({}))
|
||||
assert doc.metadata == {}
|
||||
|
||||
|
||||
def test_mixed_metadata_values() -> None:
|
||||
raw = {
|
||||
"str_val": "hello",
|
||||
"int_val": 99,
|
||||
"list_val": [1, "two", 3.0],
|
||||
"dict_val": {"nested": True},
|
||||
}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {
|
||||
"str_val": "hello",
|
||||
"int_val": "99",
|
||||
"list_val": ["1", "two", "3.0"],
|
||||
"dict_val": "{'nested': True}",
|
||||
}
|
||||
|
||||
|
||||
def test_coercion_works_on_base_class() -> None:
|
||||
kwargs = _minimal_doc_kwargs({"count": 42})
|
||||
kwargs.pop("source")
|
||||
kwargs.pop("id")
|
||||
doc = DocumentBase(**kwargs)
|
||||
assert doc.metadata == {"count": "42"}
|
||||
@@ -1,52 +0,0 @@
|
||||
import pytest
|
||||
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
|
||||
|
||||
|
||||
def test_resolve_global_defaults() -> None:
|
||||
env = resolve_microsoft_environment(
|
||||
"https://graph.microsoft.com", "https://login.microsoftonline.com"
|
||||
)
|
||||
assert env.environment == AzureEnvironment.Global
|
||||
assert env.sharepoint_domain_suffix == "sharepoint.com"
|
||||
|
||||
|
||||
def test_resolve_gcc_high() -> None:
|
||||
env = resolve_microsoft_environment(
|
||||
"https://graph.microsoft.us", "https://login.microsoftonline.us"
|
||||
)
|
||||
assert env.environment == AzureEnvironment.USGovernmentHigh
|
||||
assert env.graph_host == "https://graph.microsoft.us"
|
||||
assert env.authority_host == "https://login.microsoftonline.us"
|
||||
assert env.sharepoint_domain_suffix == "sharepoint.us"
|
||||
|
||||
|
||||
def test_resolve_dod() -> None:
|
||||
env = resolve_microsoft_environment(
|
||||
"https://dod-graph.microsoft.us", "https://login.microsoftonline.us"
|
||||
)
|
||||
assert env.environment == AzureEnvironment.USGovernmentDoD
|
||||
assert env.sharepoint_domain_suffix == "sharepoint.us"
|
||||
|
||||
|
||||
def test_trailing_slashes_are_stripped() -> None:
|
||||
env = resolve_microsoft_environment(
|
||||
"https://graph.microsoft.us/", "https://login.microsoftonline.us/"
|
||||
)
|
||||
assert env.environment == AzureEnvironment.USGovernmentHigh
|
||||
|
||||
|
||||
def test_mismatched_authority_raises() -> None:
|
||||
with pytest.raises(ConnectorValidationError, match="inconsistent"):
|
||||
resolve_microsoft_environment(
|
||||
"https://graph.microsoft.us", "https://login.microsoftonline.com"
|
||||
)
|
||||
|
||||
|
||||
def test_unknown_graph_host_raises() -> None:
|
||||
with pytest.raises(ConnectorValidationError, match="Unsupported"):
|
||||
resolve_microsoft_environment(
|
||||
"https://graph.example.com", "https://login.example.com"
|
||||
)
|
||||
@@ -97,7 +97,6 @@ class TestScimDALUserMappings:
|
||||
assert model_attrs(added_obj) == {
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user