mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-14 12:12:40 +00:00
Compare commits
31 Commits
xlsx-parse
...
v3.0.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
923e0691aa | ||
|
|
b232e2a771 | ||
|
|
c3ebfeda2f | ||
|
|
6a28dfedb1 | ||
|
|
a123ec083d | ||
|
|
f448f1274d | ||
|
|
d12f8b94aa | ||
|
|
355fe2ff2c | ||
|
|
8ec5423a0c | ||
|
|
79b615db46 | ||
|
|
98756bccd4 | ||
|
|
418f84ccdf | ||
|
|
d37756a884 | ||
|
|
9cdc92441b | ||
|
|
b8ed30644a | ||
|
|
d7d19e5a28 | ||
|
|
948650829d | ||
|
|
b6e689be0f | ||
|
|
85877408c8 | ||
|
|
c00df75c79 | ||
|
|
6352c9a09e | ||
|
|
3065f70d7d | ||
|
|
4befbc49dc | ||
|
|
ae9679e8c4 | ||
|
|
ea0ddee5c8 | ||
|
|
2826405dd2 | ||
|
|
8485bf4368 | ||
|
|
7bb52b0839 | ||
|
|
85a54c01f1 | ||
|
|
e4577bd564 | ||
|
|
f150a7b940 |
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -151,7 +151,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
69
.github/workflows/storybook-deploy.yml
vendored
69
.github/workflows/storybook-deploy.yml
vendored
@@ -1,69 +0,0 @@
|
||||
name: Storybook Deploy
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
concurrency:
|
||||
group: storybook-deploy-production
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "web/lib/opal/**"
|
||||
- "web/src/refresh-components/**"
|
||||
- "web/.storybook/**"
|
||||
- "web/package.json"
|
||||
- "web/package-lock.json"
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
Deploy-Storybook:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: npm ci
|
||||
|
||||
- name: Build Storybook
|
||||
working-directory: web
|
||||
run: npm run storybook:build
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
if: always() && needs.Deploy-Storybook.result == 'failure'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
sparse-checkout: .github/actions/slack-notify
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• Deploy-Storybook"
|
||||
title: "🚨 Storybook Deploy Failed"
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
3
.vscode/env_template.txt
vendored
3
.vscode/env_template.txt
vendored
@@ -7,9 +7,6 @@
|
||||
|
||||
|
||||
AUTH_TYPE=basic
|
||||
# Recommended for basic auth - used for signing password reset and verification tokens
|
||||
# Generate a secure value with: openssl rand -hex 32
|
||||
USER_AUTH_SECRET=""
|
||||
DEV_MODE=true
|
||||
|
||||
|
||||
|
||||
@@ -143,7 +143,6 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""add hierarchy_node_by_connector_credential_pair table
|
||||
|
||||
Revision ID: b5c4d7e8f9a1
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-03-04
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "b5c4d7e8f9a1"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_id", sa.Integer(), nullable=False),
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["hierarchy_node_id"],
|
||||
["hierarchy_node.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"hierarchy_node_by_connector_credential_pair",
|
||||
["connector_id", "credential_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
table_name="hierarchy_node_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_table("hierarchy_node_by_connector_credential_pair")
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -21,52 +22,59 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
if MULTI_TENANT:
|
||||
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -9,15 +9,12 @@ from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from onyx.access.access import collect_user_file_access
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -119,68 +116,6 @@ def _get_access_for_documents(
|
||||
return access_map
|
||||
|
||||
|
||||
def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
|
||||
"""Extract user-group names from the already-loaded Persona.groups
|
||||
relationships on a UserFile (skipping deleted personas)."""
|
||||
groups: set[str] = set()
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
for group in persona.groups:
|
||||
groups.add(group.name)
|
||||
return groups
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: extends the MIT user file ACL with user group names
|
||||
from personas shared via user groups.
|
||||
|
||||
Uses a single DB query (via fetch_user_files_with_access_relationships)
|
||||
that eagerly loads both the MIT-needed and EE-needed relationships.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
user_file_ids, db_session, eager_load_groups=True
|
||||
)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""EE version: works on pre-loaded UserFile objects.
|
||||
Expects Persona.groups to be eagerly loaded.
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
if user_file.user is None:
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
continue
|
||||
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
group_names = _collect_user_file_group_names(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
user_groups=list(group_names),
|
||||
is_public=is_public,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
@@ -21,13 +20,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode
|
||||
|
||||
|
||||
def _build_hierarchy_access_filter(
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build SQLAlchemy filter for hierarchy node access.
|
||||
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import mark_persona_user_files_for_sync
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
@@ -27,9 +26,7 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -38,7 +35,6 @@ def update_persona_access(
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -58,7 +54,6 @@ def update_persona_access(
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -68,7 +63,3 @@ def update_persona_access(
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} "
|
||||
f"in group {group_name} has no visible email address"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
|
||||
return group_member_emails
|
||||
return emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -130,12 +127,26 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
|
||||
@@ -14,91 +14,67 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
return encoded_key
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_str.encode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_bytes.decode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
return decrypted_data.decode()
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
return versioned_encryption_fn(input_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -11,7 +12,6 @@ from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -132,61 +132,19 @@ def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "get_access_for_user_files_impl"
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
return versioned_fn(user_file_ids, db_session)
|
||||
|
||||
|
||||
def get_access_for_user_files_impl(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
|
||||
return build_access_for_user_files_impl(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Compute access from pre-loaded UserFile objects (with relationships).
|
||||
Callers must ensure UserFile.user, Persona.users, and Persona.user are
|
||||
eagerly loaded (and Persona.groups for the EE path)."""
|
||||
versioned_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "build_access_for_user_files_impl"
|
||||
)
|
||||
return versioned_fn(user_files)
|
||||
|
||||
|
||||
def build_access_for_user_files_impl(
|
||||
user_files: list[UserFile],
|
||||
) -> dict[str, DocumentAccess]:
|
||||
result: dict[str, DocumentAccess] = {}
|
||||
for user_file in user_files:
|
||||
emails, is_public = collect_user_file_access(user_file)
|
||||
result[str(user_file.id)] = DocumentAccess.build(
|
||||
user_emails=list(emails),
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
user_groups=[],
|
||||
is_public=is_public,
|
||||
is_public=True if user_file.user is None else False,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
|
||||
"""Collect all user emails that should have access to this user file.
|
||||
Includes the owner plus any users who have access via shared personas.
|
||||
Returns (emails, is_public)."""
|
||||
emails: set[str] = {user_file.user.email}
|
||||
is_public = False
|
||||
for persona in user_file.assistants:
|
||||
if persona.deleted:
|
||||
continue
|
||||
if persona.is_public:
|
||||
is_public = True
|
||||
if persona.user_id is not None and persona.user:
|
||||
emails.add(persona.user.email)
|
||||
for shared_user in persona.users:
|
||||
emails.add(shared_user.email)
|
||||
return emails, is_public
|
||||
for user_file in user_files
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
@@ -146,22 +145,10 @@ def is_user_admin(user: User) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
"""Log warnings for AUTH_TYPE issues.
|
||||
|
||||
This only runs on app startup not during migrations/scripts.
|
||||
"""
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
|
||||
if raw_auth_type == "cloud":
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise ValueError(
|
||||
"'cloud' is not a valid auth type for self-hosted deployments."
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -115,6 +115,8 @@ def _extract_from_batch(
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
@@ -123,7 +125,8 @@ def _extract_from_batch(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
ids[item.id] = item.parent_hierarchy_raw_node_id
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
|
||||
|
||||
@@ -189,7 +192,9 @@ def extract_ids_from_runnable_connector(
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -40,7 +40,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
@@ -290,14 +289,6 @@ def _run_hierarchy_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
|
||||
|
||||
@@ -48,15 +48,10 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -65,7 +60,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
@@ -585,12 +579,11 @@ def connector_pruning_generator_task(
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
@@ -599,14 +592,6 @@ def connector_pruning_generator_task(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -622,6 +607,7 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
@@ -678,43 +664,6 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
# --- Hierarchy node pruning ---
|
||||
live_node_ids = {n.id for n in upserted_nodes}
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
live_hierarchy_node_ids=live_node_ids,
|
||||
commit=True,
|
||||
)
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
reparented_nodes = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
if deleted_raw_ids:
|
||||
evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
|
||||
if reparented_nodes:
|
||||
reparented_cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in reparented_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client, source, reparented_cache_entries
|
||||
)
|
||||
if stale_removed or deleted_raw_ids or reparented_nodes:
|
||||
task_logger.info(
|
||||
f"Hierarchy node pruning: cc_pair={cc_pair_id} "
|
||||
f"stale_entries_removed={stale_removed} "
|
||||
f"nodes_deleted={len(deleted_raw_ids)} "
|
||||
f"nodes_reparented={len(reparented_nodes)}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Pruning exceptioned: cc_pair={cc_pair_id} "
|
||||
|
||||
@@ -12,9 +12,9 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -43,9 +43,7 @@ from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.user_file import fetch_user_files_with_access_relationships
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -56,7 +54,6 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def _as_uuid(value: str | UUID) -> UUID:
|
||||
@@ -794,12 +791,11 @@ def project_sync_user_file_impl(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_files = fetch_user_files_with_access_relationships(
|
||||
[user_file_id],
|
||||
db_session,
|
||||
eager_load_groups=global_version.is_ee_version(),
|
||||
)
|
||||
user_file = user_files[0] if user_files else None
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
@@ -827,21 +823,12 @@ def project_sync_user_file_impl(
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
|
||||
file_id_str = str(user_file.id)
|
||||
access_map = build_access_for_user_files([user_file])
|
||||
access = access_map.get(file_id_str)
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=file_id_str,
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=(
|
||||
VespaDocumentFields(access=access)
|
||||
if access is not None
|
||||
else None
|
||||
),
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
|
||||
@@ -45,7 +45,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -588,14 +587,6 @@ def connector_document_extraction(
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution during doc processing
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache_entries = [
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -55,7 +54,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -1011,7 +1009,6 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1218,14 +1215,7 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,10 +68,6 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -96,12 +92,19 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
# Silently default to basic - warnings/errors logged in verify_auth_setting()
|
||||
# which only runs on app startup, not during migrations/scripts
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Defaulting to 'basic'. Please update your configuration. "
|
||||
"Your existing data will be migrated automatically."
|
||||
)
|
||||
_auth_type_str = AuthType.BASIC.value
|
||||
try:
|
||||
AUTH_TYPE = AuthType(_auth_type_str)
|
||||
else:
|
||||
except ValueError:
|
||||
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
|
||||
AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
@@ -204,12 +207,6 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET:
|
||||
logger.warning(
|
||||
"USER_AUTH_SECRET is not set. This is required for secure password reset "
|
||||
"and email verification tokens. Please set USER_AUTH_SECRET in production."
|
||||
)
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
|
||||
@@ -44,7 +44,6 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import unwrap_str
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
@@ -90,7 +89,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id))))
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -179,9 +178,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id),
|
||||
{"value": params.get("state", [None])[0]},
|
||||
encrypt=True,
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
)
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from office365.runtime.queries.client_query import ClientQuery # type: ignore[i
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
@@ -268,6 +269,32 @@ class SizeCapExceeded(Exception):
|
||||
"""Exception raised when the size cap is exceeded."""
|
||||
|
||||
|
||||
def _log_and_raise_for_status(response: requests.Response) -> None:
|
||||
"""Log the response text and raise for status."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
logger.error(f"HTTP request failed: {response.text}")
|
||||
raise
|
||||
|
||||
|
||||
GRAPH_INVALID_REQUEST_CODE = "invalidRequest"
|
||||
|
||||
|
||||
def _is_graph_invalid_request(response: requests.Response) -> bool:
|
||||
"""Return True if the response body is the generic Graph API
|
||||
``{"error": {"code": "invalidRequest", "message": "Invalid request"}}``
|
||||
shape. This particular error has no actionable inner error code and is
|
||||
returned by the site-pages endpoint when a page has a corrupt canvas layout
|
||||
(e.g. duplicate web-part IDs — see SharePoint/sp-dev-docs#8822)."""
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception:
|
||||
return False
|
||||
error = body.get("error", {})
|
||||
return error.get("code") == GRAPH_INVALID_REQUEST_CODE
|
||||
|
||||
|
||||
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
|
||||
"""Load certificate from .pfx file for MSAL authentication"""
|
||||
try:
|
||||
@@ -344,7 +371,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
"""Determine remote size using HEAD or a range GET probe. Returns None if unknown."""
|
||||
try:
|
||||
head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
|
||||
head_resp.raise_for_status()
|
||||
_log_and_raise_for_status(head_resp)
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
return int(cl)
|
||||
@@ -359,7 +386,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
timeout=timeout,
|
||||
stream=True,
|
||||
) as range_resp:
|
||||
range_resp.raise_for_status()
|
||||
_log_and_raise_for_status(range_resp)
|
||||
cr = range_resp.headers.get("Content-Range") # e.g., "bytes 0-0/12345"
|
||||
if cr and "/" in cr:
|
||||
total = cr.split("/")[-1]
|
||||
@@ -384,7 +411,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
- Returns the full bytes if the content fits within `cap`.
|
||||
"""
|
||||
with requests.get(url, stream=True, timeout=timeout) as resp:
|
||||
resp.raise_for_status()
|
||||
_log_and_raise_for_status(resp)
|
||||
|
||||
# If the server provides Content-Length, prefer an early decision.
|
||||
cl_header = resp.headers.get("Content-Length")
|
||||
@@ -428,7 +455,7 @@ def _download_via_graph_api(
|
||||
with requests.get(
|
||||
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
_log_and_raise_for_status(resp)
|
||||
buf = io.BytesIO()
|
||||
for chunk in resp.iter_content(64 * 1024):
|
||||
if not chunk:
|
||||
@@ -1238,26 +1265,135 @@ class SharepointConnector(
|
||||
site.execute_query()
|
||||
site_id = site.id
|
||||
|
||||
page_url: str | None = (
|
||||
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
|
||||
site_pages_base = (
|
||||
f"{self.graph_api_base}/sites/{site_id}/pages/microsoft.graph.sitePage"
|
||||
)
|
||||
page_url: str | None = site_pages_base
|
||||
params: dict[str, str] | None = {"$expand": "canvasLayout"}
|
||||
total_yielded = 0
|
||||
yielded_ids: set[str] = set()
|
||||
|
||||
while page_url:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
logger.warning(f"Site page not found: {page_url}")
|
||||
break
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout on the LIST endpoint returned 400 "
|
||||
f"for site {site_descriptor.url}. Falling back to "
|
||||
f"per-page expansion."
|
||||
)
|
||||
yield from self._fetch_site_pages_individually(
|
||||
site_pages_base, start, end, skip_ids=yielded_ids
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
params = None # nextLink already embeds query params
|
||||
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
total_yielded += 1
|
||||
page_id = page.get("id")
|
||||
if page_id:
|
||||
yielded_ids.add(page_id)
|
||||
yield page
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
|
||||
|
||||
def _fetch_site_pages_individually(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
skip_ids: set[str] | None = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Fallback for _fetch_site_pages: list pages without $expand, then
|
||||
expand canvasLayout on each page individually.
|
||||
|
||||
The Graph API's LIST endpoint can return 400 when $expand=canvasLayout
|
||||
is used and *any* page in the site has a corrupt canvas layout (e.g.
|
||||
duplicate web part IDs — see SharePoint/sp-dev-docs#8822). Since the
|
||||
LIST expansion is all-or-nothing, a single bad page poisons the entire
|
||||
response. This method works around it by fetching metadata first, then
|
||||
expanding each page individually so only the broken page loses its
|
||||
canvas content.
|
||||
|
||||
``skip_ids`` contains page IDs already yielded by the caller before the
|
||||
fallback was triggered, preventing duplicates.
|
||||
"""
|
||||
page_url: str | None = site_pages_base
|
||||
total_yielded = 0
|
||||
_skip_ids = skip_ids or set()
|
||||
|
||||
while page_url:
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
break
|
||||
raise
|
||||
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
|
||||
page_id = page.get("id")
|
||||
if page_id and page_id in _skip_ids:
|
||||
continue
|
||||
|
||||
if not page_id:
|
||||
total_yielded += 1
|
||||
yield page
|
||||
continue
|
||||
|
||||
expanded = self._try_expand_single_page(site_pages_base, page_id, page)
|
||||
total_yielded += 1
|
||||
yield expanded
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(
|
||||
f"Yielded {total_yielded} site pages (per-page expansion fallback)"
|
||||
)
|
||||
|
||||
def _try_expand_single_page(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
page_id: str,
|
||||
fallback_page: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Try to GET a single page with $expand=canvasLayout. On 400, return
|
||||
the metadata-only fallback so the page is still indexed (without canvas
|
||||
content)."""
|
||||
pages_collection = site_pages_base.removesuffix("/microsoft.graph.sitePage")
|
||||
single_url = f"{pages_collection}/{page_id}/microsoft.graph.sitePage"
|
||||
try:
|
||||
return self._graph_api_get_json(single_url, {"$expand": "canvasLayout"})
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
page_name = fallback_page.get("name", page_id)
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout failed for page '{page_name}' "
|
||||
f"({page_id}). Indexing metadata only."
|
||||
)
|
||||
return fallback_page
|
||||
raise
|
||||
|
||||
def _acquire_token(self) -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
@@ -1309,7 +1445,7 @@ class SharepointConnector(
|
||||
access_token = self._get_graph_access_token()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
continue
|
||||
response.raise_for_status()
|
||||
_log_and_raise_for_status(response)
|
||||
return response.json()
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
|
||||
@@ -2,10 +2,7 @@
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -13,7 +10,6 @@ from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -462,7 +458,7 @@ def get_all_hierarchy_nodes_for_source(
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str, # noqa: ARG001
|
||||
user_email: str | None, # noqa: ARG001
|
||||
external_group_ids: list[str], # noqa: ARG001
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -489,7 +485,7 @@ def _get_accessible_hierarchy_nodes_for_source(
|
||||
def get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
@@ -624,154 +620,3 @@ def update_hierarchy_node_permissions(
|
||||
db_session.flush()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
hierarchy_node_ids: list[int],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.
|
||||
|
||||
This records that the given cc_pair "owns" these hierarchy nodes. Used by
|
||||
indexing, pruning, and hierarchy-fetching paths.
|
||||
"""
|
||||
if not hierarchy_node_ids:
|
||||
return
|
||||
|
||||
_M = HierarchyNodeByConnectorCredentialPair
|
||||
stmt = pg_insert(_M).values(
|
||||
[
|
||||
{
|
||||
_M.hierarchy_node_id: node_id,
|
||||
_M.connector_id: connector_id,
|
||||
_M.credential_id: credential_id,
|
||||
}
|
||||
for node_id in hierarchy_node_ids
|
||||
]
|
||||
)
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
db_session.execute(stmt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
live_hierarchy_node_ids: set[int],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Delete join-table rows for this cc_pair that are NOT in the live set.
|
||||
|
||||
If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
|
||||
(i.e. the connector no longer has any hierarchy nodes). Callers that want a
|
||||
no-op when there are no live nodes must guard before calling.
|
||||
|
||||
Returns the number of deleted rows.
|
||||
"""
|
||||
stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
if live_hierarchy_node_ids:
|
||||
stmt = stmt.where(
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
|
||||
live_hierarchy_node_ids
|
||||
)
|
||||
)
|
||||
|
||||
result: CursorResult = db_session.execute(stmt) # type: ignore[assignment]
|
||||
deleted = result.rowcount
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif deleted:
|
||||
db_session.flush()
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def delete_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[str]:
|
||||
"""Delete hierarchy nodes for a source that have zero cc_pair associations.
|
||||
|
||||
SOURCE-type nodes are excluded (they are synthetic roots).
|
||||
|
||||
Returns the list of raw_node_ids that were deleted (for cache eviction).
|
||||
"""
|
||||
# Find orphaned nodes: no rows in the join table
|
||||
orphan_stmt = (
|
||||
select(HierarchyNode.id, HierarchyNode.raw_node_id)
|
||||
.outerjoin(
|
||||
HierarchyNodeByConnectorCredentialPair,
|
||||
HierarchyNode.id
|
||||
== HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
|
||||
)
|
||||
.where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
|
||||
)
|
||||
)
|
||||
orphans = db_session.execute(orphan_stmt).all()
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
orphan_ids = [row[0] for row in orphans]
|
||||
deleted_raw_ids = [row[1] for row in orphans]
|
||||
|
||||
db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return deleted_raw_ids
|
||||
|
||||
|
||||
def reparent_orphaned_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
commit: bool = True,
|
||||
) -> list[HierarchyNode]:
|
||||
"""Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.
|
||||
|
||||
After pruning deletes stale nodes, their former children get parent_id=NULL
|
||||
via the SET NULL cascade. This function points them back to the SOURCE root.
|
||||
|
||||
Returns the reparented HierarchyNode objects (with updated parent_id)
|
||||
so callers can refresh downstream caches.
|
||||
"""
|
||||
source_node = get_source_hierarchy_node(db_session, source)
|
||||
if not source_node:
|
||||
return []
|
||||
|
||||
stmt = select(HierarchyNode).where(
|
||||
HierarchyNode.source == source,
|
||||
HierarchyNode.parent_id.is_(None),
|
||||
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
|
||||
)
|
||||
orphans = list(db_session.execute(stmt).scalars().all())
|
||||
if not orphans:
|
||||
return []
|
||||
|
||||
for node in orphans:
|
||||
node.parent_id = source_node.id
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
db_session.flush()
|
||||
|
||||
return orphans
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -369,9 +370,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
models: list[SyncModelEntry],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -379,7 +380,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -393,21 +394,20 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
if model.name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.get("supports_image_input", False):
|
||||
if model.supports_image_input:
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model_name,
|
||||
model_name=model.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
@@ -37,11 +36,9 @@ from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
@@ -120,50 +117,10 @@ class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class _EncryptedBase(TypeDecorator):
|
||||
"""Base for encrypted column types that wrap values in SensitiveValue."""
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def wrap_raw(self, value: Any) -> SensitiveValue:
|
||||
"""Encrypt a raw value and wrap it in SensitiveValue.
|
||||
|
||||
Called by the attribute set event so the Python-side type is always
|
||||
SensitiveValue, regardless of whether the value was loaded from the DB
|
||||
or assigned in application code.
|
||||
"""
|
||||
if self._is_json:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f"EncryptedJson column expected dict, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = json.dumps(value)
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"EncryptedString column expected str, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=encrypt_string_to_bytes(raw_str),
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=self._is_json,
|
||||
)
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
@@ -187,9 +144,20 @@ class EncryptedString(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
_is_json: bool = True
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
@@ -197,7 +165,9 @@ class EncryptedJson(_EncryptedBase):
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
@@ -214,40 +184,14 @@ class EncryptedJson(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
_REGISTERED_ATTRS: set[str] = set()
|
||||
|
||||
|
||||
@event.listens_for(Mapper, "mapper_configured")
|
||||
def _register_sensitive_value_set_events(
|
||||
mapper: Mapper,
|
||||
class_: type,
|
||||
) -> None:
|
||||
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, _EncryptedBase):
|
||||
col_type = col.type
|
||||
attr = getattr(class_, prop.key)
|
||||
|
||||
# Guard against double-registration (e.g. if mapper is
|
||||
# re-configured in test setups)
|
||||
attr_key = f"{class_.__qualname__}.{prop.key}"
|
||||
if attr_key in _REGISTERED_ATTRS:
|
||||
continue
|
||||
_REGISTERED_ATTRS.add(attr_key)
|
||||
|
||||
@event.listens_for(attr, "set", retval=True)
|
||||
def _wrap_value(
|
||||
target: Any, # noqa: ARG001
|
||||
value: Any,
|
||||
oldvalue: Any, # noqa: ARG001
|
||||
initiator: Any, # noqa: ARG001
|
||||
_col_type: _EncryptedBase = col_type,
|
||||
) -> Any:
|
||||
if value is not None and not isinstance(value, SensitiveValue):
|
||||
return _col_type.wrap_raw(value)
|
||||
return value
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -2426,38 +2370,6 @@ class SyncRecord(Base):
|
||||
)
|
||||
|
||||
|
||||
class HierarchyNodeByConnectorCredentialPair(Base):
|
||||
"""Tracks which cc_pairs reference each hierarchy node.
|
||||
|
||||
During pruning, stale entries are removed for the current cc_pair.
|
||||
Hierarchy nodes with zero remaining entries are then deleted.
|
||||
"""
|
||||
|
||||
__tablename__ = "hierarchy_node_by_connector_credential_pair"
|
||||
|
||||
hierarchy_node_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
connector_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["connector_id", "credential_id"],
|
||||
[
|
||||
"connector_credential_pair.connector_id",
|
||||
"connector_credential_pair.credential_id",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
Index(
|
||||
"ix_hierarchy_node_cc_pair_connector_credential",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentByConnectorCredentialPair(Base):
|
||||
"""Represents an indexing of a document by a specific connector / credential pair"""
|
||||
|
||||
|
||||
@@ -205,9 +205,7 @@ def update_persona_access(
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
needs_sync = False
|
||||
if is_public is not None:
|
||||
needs_sync = True
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
@@ -215,7 +213,6 @@ def update_persona_access(
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -236,7 +233,6 @@ def update_persona_access(
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
needs_sync = True
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -244,10 +240,6 @@ def update_persona_access(
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
|
||||
# When sharing changes, user file ACLs need to be updated in the vector DB
|
||||
if needs_sync:
|
||||
mark_persona_user_files_for_sync(persona_id, db_session)
|
||||
|
||||
|
||||
def create_update_persona(
|
||||
persona_id: int | None,
|
||||
@@ -859,24 +851,6 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_persona_user_files_for_sync(
|
||||
persona_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When persona sharing changes, mark all of its user files for sync
|
||||
so that their ACLs get updated in the vector DB."""
|
||||
persona = (
|
||||
db_session.query(Persona)
|
||||
.options(selectinload(Persona.user_files))
|
||||
.filter(Persona.id == persona_id)
|
||||
.first()
|
||||
)
|
||||
if not persona:
|
||||
return
|
||||
file_ids = [uf.id for uf in persona.user_files]
|
||||
_mark_files_need_persona_sync(db_session, file_ids)
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -3,11 +3,9 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
|
||||
@@ -120,31 +118,3 @@ def get_file_ids_by_user_file_ids(
|
||||
) -> list[str]:
|
||||
user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
return [user_file.file_id for user_file in user_files]
|
||||
|
||||
|
||||
def fetch_user_files_with_access_relationships(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
eager_load_groups: bool = False,
|
||||
) -> list[UserFile]:
|
||||
"""Fetch user files with the owner and assistant relationships
|
||||
eagerly loaded (needed for computing access control).
|
||||
|
||||
When eager_load_groups is True, Persona.groups is also loaded so that
|
||||
callers can extract user-group names without a second DB round-trip."""
|
||||
persona_sub_options = [
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.user),
|
||||
]
|
||||
if eager_load_groups:
|
||||
persona_sub_options.append(selectinload(Persona.groups))
|
||||
|
||||
return (
|
||||
db_session.query(UserFile)
|
||||
.options(
|
||||
joinedload(UserFile.user),
|
||||
selectinload(UserFile.assistants).options(*persona_sub_options),
|
||||
)
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@@ -20,7 +19,6 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
@@ -354,65 +352,6 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
# Column cleanup (transpose, clean, transpose back)
|
||||
transposed = list(map(list, zip(*matrix))) if matrix else []
|
||||
transposed = _remove_empty_runs(transposed, max_empty=MAX_EMPTY_COLS)
|
||||
matrix = list(map(list, zip(*transposed))) if transposed else []
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Leading and trailing empty rows are always dropped regardless of run length,
|
||||
since there is no adjacent non-empty row to bound the run.
|
||||
"""
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
@@ -451,15 +390,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise
|
||||
raise e
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
@@ -19,16 +19,12 @@ class OnyxMimeTypes:
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-log",
|
||||
"text/x-config",
|
||||
"text/tab-separated-values",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"text/yaml",
|
||||
"text/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
PDF_MIME_TYPE,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import abc
|
||||
from typing import cast
|
||||
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -8,19 +7,6 @@ class KvKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def unwrap_str(val: JSON_ro) -> str:
|
||||
"""Unwrap a string stored as {"value": str} in the encrypted KV store.
|
||||
Also handles legacy plain-string values cached in Redis."""
|
||||
if isinstance(val, dict):
|
||||
try:
|
||||
return cast(str, val["value"])
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Expected dict with 'value' key, got keys: {list(val.keys())}"
|
||||
)
|
||||
return cast(str, val)
|
||||
|
||||
|
||||
class KeyValueStore:
|
||||
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
|
||||
# It's read from the global thread level variable
|
||||
|
||||
@@ -43,6 +43,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
]
|
||||
|
||||
|
||||
@@ -59,6 +60,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -109,6 +111,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -11,6 +11,8 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
@@ -47,6 +48,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
}
|
||||
|
||||
|
||||
@@ -331,6 +333,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -41,6 +44,51 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -157,6 +205,20 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -207,6 +269,7 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -231,6 +294,16 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -285,6 +358,7 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -16,7 +16,6 @@ Cache Strategy:
|
||||
using only the SOURCE-type node as the ancestor
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -205,30 +204,6 @@ def cache_hierarchy_nodes_batch(
|
||||
redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
def evict_hierarchy_nodes_from_cache(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_node_ids: list[str],
|
||||
) -> None:
|
||||
"""Remove specific hierarchy nodes from the Redis cache.
|
||||
|
||||
Deletes entries from both the parent-chain hash and the raw_id→node_id hash.
|
||||
"""
|
||||
if not raw_node_ids:
|
||||
return
|
||||
|
||||
cache_key = _cache_key(source)
|
||||
raw_id_key = _raw_id_cache_key(source)
|
||||
|
||||
# Look up node_ids so we can remove them from the parent-chain hash
|
||||
raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids))
|
||||
node_id_strs = [v for v in raw_values if v is not None]
|
||||
|
||||
if node_id_strs:
|
||||
redis_client.hdel(cache_key, *node_id_strs)
|
||||
redis_client.hdel(raw_id_key, *raw_node_ids)
|
||||
|
||||
|
||||
def get_node_id_from_raw_id(
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
|
||||
@@ -1905,7 +1905,7 @@ def get_connector_by_id(
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
@@ -1918,7 +1918,7 @@ def submit_connector_request(
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email
|
||||
user_email = user.email if user else None
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
|
||||
@@ -57,6 +57,9 @@ def list_messages(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageListResponse:
|
||||
"""Get all messages for a build session."""
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
messages = session_manager.list_messages(session_id, user.id)
|
||||
|
||||
@@ -54,14 +54,18 @@ def _require_opensearch(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
|
||||
def _get_user_access_info(
|
||||
user: User | None, db_session: Session
|
||||
) -> tuple[str | None, list[str]]:
|
||||
if not user:
|
||||
return None, []
|
||||
return user.email, get_user_external_group_ids(db_session, user)
|
||||
|
||||
|
||||
@router.get(HIERARCHY_NODES_LIST_PATH)
|
||||
def list_accessible_hierarchy_nodes(
|
||||
source: DocumentSource,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodesResponse:
|
||||
_require_opensearch(db_session)
|
||||
@@ -88,7 +92,7 @@ def list_accessible_hierarchy_nodes(
|
||||
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
|
||||
def list_accessible_hierarchy_node_documents(
|
||||
documents_request: HierarchyNodeDocumentsRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodeDocumentsResponse:
|
||||
_require_opensearch(db_session)
|
||||
|
||||
@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
|
||||
@router.get("/servers", response_model=MCPServersResponse)
|
||||
def get_mcp_servers_for_user(
|
||||
db: Session = Depends(get_session),
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> MCPServersResponse:
|
||||
"""List all MCP servers for use in agent configuration and chat UI.
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -37,38 +35,6 @@ def get_safe_filename(upload: UploadFile) -> str:
|
||||
return upload.filename
|
||||
|
||||
|
||||
def get_upload_size_bytes(upload: UploadFile) -> int | None:
|
||||
"""Best-effort file size in bytes without consuming the stream."""
|
||||
if upload.size is not None:
|
||||
return upload.size
|
||||
|
||||
try:
|
||||
current_pos = upload.file.tell()
|
||||
upload.file.seek(0, 2)
|
||||
size = upload.file.tell()
|
||||
upload.file.seek(current_pos)
|
||||
return size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not determine upload size via stream seek "
|
||||
f"(filename='{get_safe_filename(upload)}', "
|
||||
f"error_type={type(e).__name__}, error={e})"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
|
||||
"""Return True when upload size is known and exceeds max_bytes."""
|
||||
size_bytes = get_upload_size_bytes(upload)
|
||||
if size_bytes is None:
|
||||
logger.warning(
|
||||
"Could not determine upload size; skipping size-limit check for "
|
||||
f"'{get_safe_filename(upload)}'"
|
||||
)
|
||||
return False
|
||||
return size_bytes > max_bytes
|
||||
|
||||
|
||||
# Guard against extremely large images
|
||||
Image.MAX_IMAGE_PIXELS = 12000 * 12000
|
||||
|
||||
@@ -193,18 +159,6 @@ def categorize_uploaded_files(
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
extension = get_file_ext(filename)
|
||||
|
||||
# If image, estimate tokens via dedicated method first
|
||||
|
||||
@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -42,7 +42,6 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -277,15 +276,6 @@ class CustomToolDelta(BaseObj):
|
||||
error: CustomToolErrorInfo | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# File Reader Packets
|
||||
################################################
|
||||
@@ -407,7 +397,6 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -78,7 +78,6 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
|
||||
@@ -3,7 +3,6 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -51,7 +50,6 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
return settings
|
||||
|
||||
@@ -56,23 +56,3 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,7 +92,3 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -376,8 +376,3 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -11,20 +11,16 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT encrypt called with explicit key — key ignored.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support decryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT decrypt called with explicit key — key ignored.")
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
@@ -90,15 +86,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(intput_str, key=key)
|
||||
return versioned_encryption_fn(intput_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(intput_bytes, key=key)
|
||||
return versioned_decryption_fn(intput_bytes)
|
||||
|
||||
@@ -128,8 +128,6 @@ class SensitiveValue(Generic[T]):
|
||||
value = self._decrypt()
|
||||
|
||||
if not apply_mask:
|
||||
# Callers must not mutate the returned dict — doing so would
|
||||
# desync the cache from the encrypted bytes and the DB.
|
||||
return value
|
||||
|
||||
# Apply masking
|
||||
@@ -176,20 +174,18 @@ class SensitiveValue(Generic[T]):
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compare SensitiveValues by their decrypted content."""
|
||||
# NOTE: if you attempt to compare a string/dict to a SensitiveValue,
|
||||
# this comparison will return NotImplemented, which then evaluates to False.
|
||||
# This is the convention and required for SQLAlchemy's attribute tracking.
|
||||
if not isinstance(other, SensitiveValue):
|
||||
return NotImplemented
|
||||
return self._decrypt() == other._decrypt()
|
||||
"""Prevent direct comparison which might expose value."""
|
||||
if isinstance(other, SensitiveValue):
|
||||
# Compare encrypted bytes for equality check
|
||||
return self._encrypted_bytes == other._encrypted_bytes
|
||||
raise SensitiveAccessError(
|
||||
"Cannot compare SensitiveValue with non-SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value for comparison."
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on decrypted content."""
|
||||
value = self._decrypt()
|
||||
if isinstance(value, dict):
|
||||
return hash(json.dumps(value, sort_keys=True))
|
||||
return hash(value)
|
||||
"""Allow hashing based on encrypted bytes."""
|
||||
return hash(self._encrypted_bytes)
|
||||
|
||||
# Prevent JSON serialization
|
||||
def __json__(self) -> Any:
|
||||
|
||||
@@ -2,6 +2,7 @@ import contextvars
|
||||
import threading
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
|
||||
@@ -14,7 +15,6 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.key_value_store.interface import unwrap_str
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
@@ -25,7 +25,6 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
|
||||
_CACHED_UUID: str | None = None
|
||||
_CACHED_INSTANCE_DOMAIN: str | None = None
|
||||
@@ -63,10 +62,10 @@ def get_or_generate_uuid() -> str:
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_UUID = unwrap_str(kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
_CACHED_UUID = str(uuid.uuid4())
|
||||
kv_store.store(KV_CUSTOMER_UUID_KEY, {"value": _CACHED_UUID}, encrypt=True)
|
||||
kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True)
|
||||
|
||||
return _CACHED_UUID
|
||||
|
||||
@@ -80,16 +79,14 @@ def _get_or_generate_instance_domain() -> str | None: #
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_INSTANCE_DOMAIN = unwrap_str(kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
first_user = db_session.query(User).first()
|
||||
if first_user:
|
||||
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
|
||||
kv_store.store(
|
||||
KV_INSTANCE_DOMAIN_KEY,
|
||||
{"value": _CACHED_INSTANCE_DOMAIN},
|
||||
encrypt=True,
|
||||
KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True
|
||||
)
|
||||
|
||||
return _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.3
|
||||
onyx-devtools==0.7.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -406,7 +406,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.4.3
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
|
||||
@@ -1,93 +1,48 @@
|
||||
"""Decrypt a raw hex-encoded credential value.
|
||||
|
||||
Usage:
|
||||
python -m scripts.decrypt <hex_value>
|
||||
python -m scripts.decrypt <hex_value> --key "my-encryption-key"
|
||||
python -m scripts.decrypt <hex_value> --key ""
|
||||
|
||||
Pass --key "" to skip decryption and just decode the raw bytes as UTF-8.
|
||||
Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
|
||||
|
||||
def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value.
|
||||
def decrypt_raw_credential(encrypted_value: str) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value
|
||||
|
||||
Args:
|
||||
encrypted_value: The hex-encoded encrypted credential value.
|
||||
key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET,
|
||||
empty string means just decode as UTF-8.
|
||||
encrypted_value: The hex encoded encrypted credential value
|
||||
"""
|
||||
# Strip common hex prefixes
|
||||
if encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
elif encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
print(encrypted_value)
|
||||
|
||||
try:
|
||||
raw_bytes = binascii.unhexlify(encrypted_value)
|
||||
# If string starts with 'x', remove it as it's just a prefix indicating hex
|
||||
if encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
elif encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
|
||||
# Convert hex string to bytes
|
||||
encrypted_bytes = binascii.unhexlify(encrypted_value)
|
||||
|
||||
# Decrypt the bytes
|
||||
decrypted_str = decrypt_bytes_to_string(encrypted_bytes)
|
||||
|
||||
# Parse and pretty print the decrypted JSON
|
||||
decrypted_json = json.loads(decrypted_str)
|
||||
print("Decrypted credential value:")
|
||||
print(json.dumps(decrypted_json, indent=2))
|
||||
|
||||
except binascii.Error:
|
||||
print("Error: Invalid hex-encoded string")
|
||||
sys.exit(1)
|
||||
print("Error: Invalid hex encoded string")
|
||||
|
||||
if key == "":
|
||||
# Empty key → just decode as UTF-8, no decryption
|
||||
try:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"Error decoding bytes as UTF-8: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(key)
|
||||
try:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key)
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Decrypted raw value (not JSON): {e}")
|
||||
|
||||
# Try to pretty-print as JSON, otherwise print raw
|
||||
try:
|
||||
parsed = json.loads(decrypted_str)
|
||||
print(json.dumps(parsed, indent=2))
|
||||
except json.JSONDecodeError:
|
||||
print(decrypted_str)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Decrypt a hex-encoded credential value."
|
||||
)
|
||||
parser.add_argument(
|
||||
"value",
|
||||
help="Hex-encoded encrypted value to decrypt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
default=None,
|
||||
help=(
|
||||
"Encryption key. Omit to use ENCRYPTION_KEY_SECRET from env. "
|
||||
'Pass "" (empty) to just decode as UTF-8 without decryption.'
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
global_version.set_ee()
|
||||
decrypt_raw_credential(args.value, key=args.key)
|
||||
global_version.unset_ee()
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python decrypt.py <hex_encoded_encrypted_value>")
|
||||
sys.exit(1)
|
||||
|
||||
encrypted_value = sys.argv[1]
|
||||
decrypt_raw_credential(encrypted_value)
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Decrypts all encrypted columns using the old key (or raw decode if the old key
|
||||
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Usage (docker):
|
||||
docker exec -it onyx-api_server-1 \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Usage (kubernetes):
|
||||
kubectl exec -it <pod> -- \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Omit --old-key (or pass "") if secrets were not previously encrypted.
|
||||
|
||||
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
|
||||
or --all-tenants to iterate every tenant.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
|
||||
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
|
||||
|
||||
|
||||
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
|
||||
print(f"Re-encrypting secrets for tenant: {tenant_id}")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
|
||||
|
||||
if results:
|
||||
for col, count in results.items():
|
||||
print(
|
||||
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
|
||||
)
|
||||
else:
|
||||
print("No rows needed re-encryption.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Re-encrypt secrets under the current encryption key."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old-key",
|
||||
default=None,
|
||||
help="Previous encryption key. Omit or pass empty string if not applicable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be re-encrypted without making changes.",
|
||||
)
|
||||
|
||||
tenant_group = parser.add_mutually_exclusive_group()
|
||||
tenant_group.add_argument(
|
||||
"--tenant-id",
|
||||
default=None,
|
||||
help="Target a specific tenant schema.",
|
||||
)
|
||||
tenant_group.add_argument(
|
||||
"--all-tenants",
|
||||
action="store_true",
|
||||
help="Iterate all tenants.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
old_key = args.old_key if args.old_key else None
|
||||
|
||||
global_version.set_ee()
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
if args.dry_run:
|
||||
print("DRY RUN — no changes will be made")
|
||||
|
||||
if args.all_tenants:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
print(f"Found {len(tenant_ids)} tenant(s)")
|
||||
failed_tenants: list[str] = []
|
||||
for tid in tenant_ids:
|
||||
try:
|
||||
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
|
||||
except Exception as e:
|
||||
print(f" ERROR for tenant {tid}: {e}")
|
||||
failed_tenants.append(tid)
|
||||
if failed_tenants:
|
||||
print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -48,7 +48,7 @@ def test_gitlab_connector_basic(gitlab_connector: GitlabConnector) -> None:
|
||||
|
||||
# --- Specific Document Details to Validate ---
|
||||
target_mr_id = f"https://{gitlab_base_url}/{project_path}/-/merge_requests/1"
|
||||
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/work_items/2"
|
||||
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/issues/2"
|
||||
target_code_file_semantic_id = "README.md"
|
||||
# ---
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ Verifies that:
|
||||
3. Upserting is idempotent (running twice doesn't duplicate nodes)
|
||||
4. Document-to-hierarchy-node linkage is updated during pruning
|
||||
5. link_hierarchy_nodes_to_documents links nodes that are also documents
|
||||
6. HierarchyNodeByConnectorCredentialPair join table population and pruning
|
||||
7. Orphaned hierarchy node deletion and re-parenting
|
||||
|
||||
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
|
||||
combined with a real PostgreSQL database for verifying persistence.
|
||||
@@ -26,27 +24,16 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import ensure_source_node_exists
|
||||
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
|
||||
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
@@ -155,80 +142,13 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_cc_pair(
|
||||
db_session: Session,
|
||||
source: DocumentSource = TEST_SOURCE,
|
||||
) -> ConnectorCredentialPair:
|
||||
"""Create a real Connector + Credential + ConnectorCredentialPair for testing."""
|
||||
connector = Connector(
|
||||
name=f"Test {source.value} Connector",
|
||||
source=source,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.flush()
|
||||
|
||||
credential = Credential(
|
||||
source=source,
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
name=f"Test {source.value} CC Pair",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
access_type=AccessType.PUBLIC,
|
||||
)
|
||||
db_session.add(cc_pair)
|
||||
db_session.commit()
|
||||
db_session.refresh(cc_pair)
|
||||
return cc_pair
|
||||
|
||||
|
||||
def _cleanup_test_data(db_session: Session) -> None:
|
||||
"""Remove all test hierarchy nodes and documents to isolate tests."""
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
|
||||
|
||||
test_connector_ids_q = db_session.query(Connector.id).filter(
|
||||
Connector.source == TEST_SOURCE,
|
||||
Connector.name.like("Test %"),
|
||||
)
|
||||
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair).filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == TEST_SOURCE
|
||||
).delete()
|
||||
db_session.flush()
|
||||
|
||||
# Collect credential IDs before deleting cc_pairs (bulk query.delete()
|
||||
# bypasses ORM-level cascade, so credentials won't be auto-removed).
|
||||
credential_ids = [
|
||||
row[0]
|
||||
for row in db_session.query(ConnectorCredentialPair.credential_id)
|
||||
.filter(ConnectorCredentialPair.connector_id.in_(test_connector_ids_q))
|
||||
.all()
|
||||
]
|
||||
|
||||
db_session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Connector).filter(
|
||||
Connector.source == TEST_SOURCE,
|
||||
Connector.name.like("Test %"),
|
||||
).delete(synchronize_session="fetch")
|
||||
if credential_ids:
|
||||
db_session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -259,8 +179,15 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
# raw_id_to_parent should contain ONLY document IDs, not hierarchy node IDs
|
||||
assert result.raw_id_to_parent.keys() == set(SLIM_DOC_IDS)
|
||||
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
|
||||
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
|
||||
expected_ids = {
|
||||
CHANNEL_A_ID,
|
||||
CHANNEL_B_ID,
|
||||
CHANNEL_C_ID,
|
||||
*SLIM_DOC_IDS,
|
||||
}
|
||||
assert result.raw_id_to_parent.keys() == expected_ids
|
||||
|
||||
# Hierarchy nodes should be the 3 channels
|
||||
assert len(result.hierarchy_nodes) == 3
|
||||
@@ -468,9 +395,9 @@ def test_extraction_preserves_parent_hierarchy_raw_node_id(
|
||||
result.raw_id_to_parent[doc_id] == expected_parent
|
||||
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
|
||||
|
||||
# Hierarchy node IDs should NOT be in raw_id_to_parent
|
||||
# Hierarchy node entries have None parent (they aren't documents)
|
||||
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
|
||||
assert channel_id not in result.raw_id_to_parent
|
||||
assert result.raw_id_to_parent[channel_id] is None
|
||||
|
||||
|
||||
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
|
||||
@@ -638,241 +565,3 @@ def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
|
||||
commit=False,
|
||||
)
|
||||
assert linked == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Join table + pruning tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_upsert_hierarchy_node_cc_pair_entries(db_session: Session) -> None:
|
||||
"""upsert_hierarchy_node_cc_pair_entries should insert rows and be idempotent."""
|
||||
_cleanup_test_data(db_session)
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_ids = [n.id for n in upserted]
|
||||
|
||||
# First call — should insert rows
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=node_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
rows = (
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair)
|
||||
.filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id
|
||||
== cc_pair.credential_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(rows) == 3
|
||||
|
||||
# Second call — idempotent, same count
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=node_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
rows_after = (
|
||||
db_session.query(HierarchyNodeByConnectorCredentialPair)
|
||||
.filter(
|
||||
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
|
||||
HierarchyNodeByConnectorCredentialPair.credential_id
|
||||
== cc_pair.credential_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
assert len(rows_after) == 3
|
||||
|
||||
|
||||
def test_remove_stale_entries_and_delete_orphans(db_session: Session) -> None:
|
||||
"""After removing stale join-table entries, orphaned hierarchy nodes should
|
||||
be deleted and the SOURCE node should survive."""
|
||||
_cleanup_test_data(db_session)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
all_ids = [n.id for n in upserted]
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Now simulate a pruning run where only channel A survived
|
||||
channel_a = get_hierarchy_node_by_raw_id(db_session, CHANNEL_A_ID, TEST_SOURCE)
|
||||
assert channel_a is not None
|
||||
live_ids = {channel_a.id}
|
||||
|
||||
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
live_hierarchy_node_ids=live_ids,
|
||||
commit=True,
|
||||
)
|
||||
assert stale_removed == 2
|
||||
|
||||
# Delete orphaned nodes
|
||||
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert set(deleted_raw_ids) == {CHANNEL_B_ID, CHANNEL_C_ID}
|
||||
|
||||
# Verify only channel A + SOURCE remain
|
||||
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
|
||||
remaining_raw = {n.raw_node_id for n in remaining}
|
||||
assert remaining_raw == {CHANNEL_A_ID, source_node.raw_node_id}
|
||||
|
||||
|
||||
def test_multi_cc_pair_prevents_premature_deletion(db_session: Session) -> None:
|
||||
"""A hierarchy node shared by two cc_pairs should NOT be deleted when only
|
||||
one cc_pair removes its association."""
|
||||
_cleanup_test_data(db_session)
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair_1 = _create_cc_pair(db_session)
|
||||
cc_pair_2 = _create_cc_pair(db_session)
|
||||
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
all_ids = [n.id for n in upserted]
|
||||
|
||||
# cc_pair 1 owns all 3
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair_1.connector_id,
|
||||
credential_id=cc_pair_1.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
# cc_pair 2 also owns all 3
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=all_ids,
|
||||
connector_id=cc_pair_2.connector_id,
|
||||
credential_id=cc_pair_2.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# cc_pair 1 prunes — keeps none
|
||||
remove_stale_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair_1.connector_id,
|
||||
credential_id=cc_pair_1.credential_id,
|
||||
live_hierarchy_node_ids=set(),
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Orphan deletion should find nothing because cc_pair 2 still references them
|
||||
deleted = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert deleted == []
|
||||
|
||||
# All 3 nodes + SOURCE should still exist
|
||||
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
|
||||
assert len(remaining) == 4
|
||||
|
||||
|
||||
def test_reparent_orphaned_children(db_session: Session) -> None:
|
||||
"""After deleting a parent hierarchy node, its children should be
|
||||
re-parented to the SOURCE node."""
|
||||
_cleanup_test_data(db_session)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
cc_pair = _create_cc_pair(db_session)
|
||||
|
||||
# Create a parent node and a child node
|
||||
parent_node = PydanticHierarchyNode(
|
||||
raw_node_id="PARENT",
|
||||
raw_parent_id=None,
|
||||
display_name="Parent",
|
||||
node_type=HierarchyNodeType.CHANNEL,
|
||||
)
|
||||
child_node = PydanticHierarchyNode(
|
||||
raw_node_id="CHILD",
|
||||
raw_parent_id="PARENT",
|
||||
display_name="Child",
|
||||
node_type=HierarchyNodeType.CHANNEL,
|
||||
)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=[parent_node, child_node],
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
assert len(upserted) == 2
|
||||
|
||||
parent_db = get_hierarchy_node_by_raw_id(db_session, "PARENT", TEST_SOURCE)
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert parent_db is not None and child_db is not None
|
||||
assert child_db.parent_id == parent_db.id
|
||||
|
||||
# Associate only the child with a cc_pair (parent is orphaned)
|
||||
upsert_hierarchy_node_cc_pair_entries(
|
||||
db_session=db_session,
|
||||
hierarchy_node_ids=[child_db.id],
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Delete orphaned nodes (parent has no cc_pair entry)
|
||||
deleted = delete_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert "PARENT" in deleted
|
||||
|
||||
# Child should now have parent_id=NULL (SET NULL cascade)
|
||||
db_session.expire_all()
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert child_db is not None
|
||||
assert child_db.parent_id is None
|
||||
|
||||
# Re-parent orphans to SOURCE
|
||||
reparented = reparent_orphaned_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
)
|
||||
assert len(reparented) == 1
|
||||
|
||||
db_session.expire_all()
|
||||
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
|
||||
assert child_db is not None
|
||||
assert child_db.parent_id == source_node.id
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Test that Credential with nested JSON round-trips through SensitiveValue correctly.
|
||||
|
||||
Exercises the full encrypt → store → read → decrypt → SensitiveValue path
|
||||
with realistic nested OAuth credential data, and verifies SQLAlchemy dirty
|
||||
tracking works with nested dict comparison.
|
||||
|
||||
Requires a running Postgres instance.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
# NOTE: this is not the real shape of a Drive credential,
|
||||
# but it is intended to test nested JSON credential handling
|
||||
|
||||
_NESTED_CRED_JSON = {
|
||||
"oauth_tokens": {
|
||||
"access_token": "ya29.abc123",
|
||||
"refresh_token": "1//xEg-def456",
|
||||
},
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"client_config": {
|
||||
"client_id": "123.apps.googleusercontent.com",
|
||||
"client_secret": "GOCSPX-secret",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_nested_credential_json_round_trip(db_session: Session) -> None:
|
||||
"""Nested OAuth credential survives encrypt → store → read → decrypt."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Immediate read (no DB round-trip) — tests the set event wrapping
|
||||
assert isinstance(credential.credential_json, SensitiveValue)
|
||||
assert credential.credential_json.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
# DB round-trip — tests process_result_value
|
||||
db_session.expire(credential)
|
||||
reloaded = credential.credential_json
|
||||
assert isinstance(reloaded, SensitiveValue)
|
||||
assert reloaded.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None:
|
||||
"""Re-assigning the same nested dict should not mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Clear dirty state from the insert
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
# Re-assign identical value
|
||||
credential.credential_json = _NESTED_CRED_JSON # type: ignore[assignment]
|
||||
assert not db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_assign_different_nested_json_is_dirty(db_session: Session) -> None:
|
||||
"""Assigning a different nested dict should mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
modified_cred = {**_NESTED_CRED_JSON, "scopes": ["read"]}
|
||||
credential.credential_json = modified_cred # type: ignore[assignment]
|
||||
assert db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
@@ -1,305 +0,0 @@
|
||||
"""Tests for rotate_encryption_key against real Postgres.
|
||||
|
||||
Uses real ORM models (Credential, InternetSearchProvider) and the actual
|
||||
Postgres database. Discovery is mocked in rotation tests to scope mutations
|
||||
to only the test rows — the real _discover_encrypted_columns walk is tested
|
||||
separately in TestDiscoverEncryptedColumns.
|
||||
|
||||
Requires a running Postgres instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
|
||||
|
||||
OLD_KEY = "o" * 16
|
||||
NEW_KEY = "n" * 16
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee() -> Generator[None, None, None]:
|
||||
prev = global_version._is_ee
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
global_version._is_ee = prev
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
|
||||
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
|
||||
col = Credential.__table__.c.credential_json
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
Credential.__table__.c.id == credential_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
|
||||
"""Read raw bytes from InternetSearchProvider.api_key."""
|
||||
col = InternetSearchProvider.__table__.c.api_key
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
InternetSearchProvider.__table__.c.id == isp_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
class TestDiscoverEncryptedColumns:
|
||||
"""Verify _discover_encrypted_columns finds real production models."""
|
||||
|
||||
def test_discovers_credential_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("credential", "credential_json", True) in found
|
||||
|
||||
def test_discovers_internet_search_provider_api_key(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("internet_search_provider", "api_key", False) in found
|
||||
|
||||
def test_all_encrypted_string_columns_are_not_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedString):
|
||||
assert not is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
def test_all_encrypted_json_columns_are_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
assert is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
|
||||
class TestRotateCredential:
|
||||
"""Test rotation against the real Credential table (EncryptedJson).
|
||||
|
||||
Discovery is scoped to only the Credential model to avoid mutating
|
||||
other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[(Credential, "credential_json", ["id"], True)],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def credential_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert a Credential row with raw encrypted bytes, clean up after."""
|
||||
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
|
||||
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO credential "
|
||||
"(source, credential_json, admin_public, curator_public) "
|
||||
"VALUES (:source, :cred_json, true, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
|
||||
)
|
||||
cred_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield cred_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_credential_json(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
assert decrypted["endpoint"] == "https://example.com"
|
||||
|
||||
def test_skips_already_rotated(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
|
||||
def test_dry_run_does_not_modify(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
original = _raw_credential_bytes(db_session, credential_id)
|
||||
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw_after = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw_after == original
|
||||
|
||||
|
||||
class TestRotateInternetSearchProvider:
|
||||
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
|
||||
|
||||
Discovery is scoped to only the InternetSearchProvider model to avoid
|
||||
mutating other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[
|
||||
(InternetSearchProvider, "api_key", ["id"], False),
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def isp_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
|
||||
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-rotation-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": encrypted,
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield isp_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
|
||||
|
||||
def test_rotates_from_unencrypted(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> None:
|
||||
"""Test rotating data that was stored without any encryption key."""
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-raw-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": b"raw-api-key",
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=None)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -85,7 +85,7 @@ def test_group_overlap_filter(
|
||||
results = _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
user_email="",
|
||||
user_email=None,
|
||||
external_group_ids=["group_engineering"],
|
||||
)
|
||||
result_ids = {n.raw_node_id for n in results}
|
||||
@@ -124,7 +124,7 @@ def test_no_credentials_returns_only_public(
|
||||
results = _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
user_email="",
|
||||
user_email=None,
|
||||
external_group_ids=[],
|
||||
)
|
||||
result_ids = {n.raw_node_id for n in results}
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Tests that SlackBot CRUD operations return properly typed SensitiveValue fields.
|
||||
|
||||
Regression test for the bug where insert_slack_bot/update_slack_bot returned
|
||||
objects with raw string tokens instead of SensitiveValue wrappers, causing
|
||||
'str object has no attribute get_value' errors in SlackBot.from_model().
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.slack_bot import insert_slack_bot
|
||||
from onyx.db.slack_bot import update_slack_bot
|
||||
from onyx.server.manage.models import SlackBot
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
def _unique(prefix: str) -> str:
|
||||
return f"{prefix}-{uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def test_insert_slack_bot_returns_sensitive_values(db_session: Session) -> None:
|
||||
bot_token = _unique("xoxb-insert")
|
||||
app_token = _unique("xapp-insert")
|
||||
user_token = _unique("xoxp-insert")
|
||||
|
||||
slack_bot = insert_slack_bot(
|
||||
db_session=db_session,
|
||||
name=_unique("test-bot-insert"),
|
||||
enabled=True,
|
||||
bot_token=bot_token,
|
||||
app_token=app_token,
|
||||
user_token=user_token,
|
||||
)
|
||||
|
||||
assert isinstance(slack_bot.bot_token, SensitiveValue)
|
||||
assert isinstance(slack_bot.app_token, SensitiveValue)
|
||||
assert isinstance(slack_bot.user_token, SensitiveValue)
|
||||
|
||||
assert slack_bot.bot_token.get_value(apply_mask=False) == bot_token
|
||||
assert slack_bot.app_token.get_value(apply_mask=False) == app_token
|
||||
assert slack_bot.user_token.get_value(apply_mask=False) == user_token
|
||||
|
||||
# Verify from_model works without error
|
||||
pydantic_bot = SlackBot.from_model(slack_bot)
|
||||
assert pydantic_bot.bot_token # masked, but not empty
|
||||
assert pydantic_bot.app_token
|
||||
|
||||
|
||||
def test_update_slack_bot_returns_sensitive_values(db_session: Session) -> None:
|
||||
slack_bot = insert_slack_bot(
|
||||
db_session=db_session,
|
||||
name=_unique("test-bot-update"),
|
||||
enabled=True,
|
||||
bot_token=_unique("xoxb-update"),
|
||||
app_token=_unique("xapp-update"),
|
||||
)
|
||||
|
||||
new_bot_token = _unique("xoxb-update-new")
|
||||
new_app_token = _unique("xapp-update-new")
|
||||
new_user_token = _unique("xoxp-update-new")
|
||||
|
||||
updated = update_slack_bot(
|
||||
db_session=db_session,
|
||||
slack_bot_id=slack_bot.id,
|
||||
name=_unique("test-bot-updated"),
|
||||
enabled=False,
|
||||
bot_token=new_bot_token,
|
||||
app_token=new_app_token,
|
||||
user_token=new_user_token,
|
||||
)
|
||||
|
||||
assert isinstance(updated.bot_token, SensitiveValue)
|
||||
assert isinstance(updated.app_token, SensitiveValue)
|
||||
assert isinstance(updated.user_token, SensitiveValue)
|
||||
|
||||
assert updated.bot_token.get_value(apply_mask=False) == new_bot_token
|
||||
assert updated.app_token.get_value(apply_mask=False) == new_app_token
|
||||
assert updated.user_token.get_value(apply_mask=False) == new_user_token
|
||||
|
||||
# Verify from_model works without error
|
||||
pydantic_bot = SlackBot.from_model(updated)
|
||||
assert pydantic_bot.bot_token
|
||||
assert pydantic_bot.app_token
|
||||
assert pydantic_bot.user_token is not None
|
||||
@@ -148,16 +148,8 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# Secrets should be preserved
|
||||
assert updated_config.client_id is not None
|
||||
assert original_client_id is not None
|
||||
assert updated_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_id.get_value(apply_mask=False)
|
||||
assert updated_config.client_secret is not None
|
||||
assert original_client_secret is not None
|
||||
assert updated_config.client_secret.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_secret.get_value(apply_mask=False)
|
||||
assert updated_config.client_id == original_client_id
|
||||
assert updated_config.client_secret == original_client_secret
|
||||
# But name should be updated
|
||||
assert updated_config.name == new_name
|
||||
|
||||
@@ -181,14 +173,9 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# client_id should be cleared (empty string)
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_id == ""
|
||||
# client_secret should be preserved
|
||||
assert updated_config.client_secret is not None
|
||||
assert original_client_secret is not None
|
||||
assert updated_config.client_secret.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_secret.get_value(apply_mask=False)
|
||||
assert updated_config.client_secret == original_client_secret
|
||||
|
||||
def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> None:
|
||||
"""Test clearing client_secret while preserving client_id"""
|
||||
@@ -203,14 +190,9 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# client_secret should be cleared (empty string)
|
||||
assert updated_config.client_secret is not None
|
||||
assert updated_config.client_secret.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_secret == ""
|
||||
# client_id should be preserved
|
||||
assert updated_config.client_id is not None
|
||||
assert original_client_id is not None
|
||||
assert updated_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_id.get_value(apply_mask=False)
|
||||
assert updated_config.client_id == original_client_id
|
||||
|
||||
def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> None:
|
||||
"""Test clearing both client_id and client_secret"""
|
||||
@@ -225,10 +207,8 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# Both should be cleared (empty strings)
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_secret is not None
|
||||
assert updated_config.client_secret.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_id == ""
|
||||
assert updated_config.client_secret == ""
|
||||
|
||||
def test_update_oauth_config_authorization_url(self, db_session: Session) -> None:
|
||||
"""Test updating authorization_url"""
|
||||
@@ -295,8 +275,7 @@ class TestOAuthConfigCRUD:
|
||||
assert updated_config.token_url == new_token_url
|
||||
assert updated_config.scopes == new_scopes
|
||||
assert updated_config.additional_params == new_params
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == new_client_id
|
||||
assert updated_config.client_id == new_client_id
|
||||
|
||||
def test_delete_oauth_config(self, db_session: Session) -> None:
|
||||
"""Test deleting an OAuth configuration"""
|
||||
@@ -437,8 +416,7 @@ class TestOAuthUserTokenCRUD:
|
||||
assert user_token.id is not None
|
||||
assert user_token.oauth_config_id == oauth_config.id
|
||||
assert user_token.user_id == user.id
|
||||
assert user_token.token_data is not None
|
||||
assert user_token.token_data.get_value(apply_mask=False) == token_data
|
||||
assert user_token.token_data == token_data
|
||||
assert user_token.created_at is not None
|
||||
assert user_token.updated_at is not None
|
||||
|
||||
@@ -468,13 +446,8 @@ class TestOAuthUserTokenCRUD:
|
||||
|
||||
# Should be the same token record (updated, not inserted)
|
||||
assert updated_token.id == initial_token_id
|
||||
assert updated_token.token_data is not None
|
||||
assert (
|
||||
updated_token.token_data.get_value(apply_mask=False) == updated_token_data
|
||||
)
|
||||
assert (
|
||||
updated_token.token_data.get_value(apply_mask=False) != initial_token_data
|
||||
)
|
||||
assert updated_token.token_data == updated_token_data
|
||||
assert updated_token.token_data != initial_token_data
|
||||
|
||||
def test_get_user_oauth_token(self, db_session: Session) -> None:
|
||||
"""Test retrieving a user's OAuth token"""
|
||||
@@ -490,8 +463,7 @@ class TestOAuthUserTokenCRUD:
|
||||
|
||||
assert retrieved_token is not None
|
||||
assert retrieved_token.id == created_token.id
|
||||
assert retrieved_token.token_data is not None
|
||||
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data
|
||||
assert retrieved_token.token_data == token_data
|
||||
|
||||
def test_get_user_oauth_token_not_found(self, db_session: Session) -> None:
|
||||
"""Test retrieving a non-existent user token returns None"""
|
||||
@@ -547,8 +519,7 @@ class TestOAuthUserTokenCRUD:
|
||||
retrieved_token = get_user_oauth_token(oauth_config.id, user.id, db_session)
|
||||
assert retrieved_token is not None
|
||||
assert retrieved_token.id == updated_token.id
|
||||
assert retrieved_token.token_data is not None
|
||||
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data2
|
||||
assert retrieved_token.token_data == token_data2
|
||||
|
||||
def test_cascade_delete_user_tokens_on_config_deletion(
|
||||
self, db_session: Session
|
||||
|
||||
@@ -374,14 +374,8 @@ class TestOAuthTokenManagerCodeExchange:
|
||||
assert call_args[0][0] == oauth_config.token_url
|
||||
assert call_args[1]["data"]["grant_type"] == "authorization_code"
|
||||
assert call_args[1]["data"]["code"] == "auth_code_123"
|
||||
assert oauth_config.client_id is not None
|
||||
assert oauth_config.client_secret is not None
|
||||
assert call_args[1]["data"]["client_id"] == oauth_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
assert call_args[1]["data"][
|
||||
"client_secret"
|
||||
] == oauth_config.client_secret.get_value(apply_mask=False)
|
||||
assert call_args[1]["data"]["client_id"] == oauth_config.client_id
|
||||
assert call_args[1]["data"]["client_secret"] == oauth_config.client_secret
|
||||
assert call_args[1]["data"]["redirect_uri"] == "https://example.com/callback"
|
||||
|
||||
@patch("onyx.auth.oauth_token_manager.requests.post")
|
||||
|
||||
@@ -950,7 +950,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
@@ -1292,21 +1291,12 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
tool_call_id="call_replay_test",
|
||||
tool_call_argument_tokens=[json.dumps({"code": code})],
|
||||
)
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type="python",
|
||||
argument_deltas={"code": code},
|
||||
),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=False,
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
|
||||
@@ -64,8 +64,7 @@ class TestBotConfigAPI:
|
||||
db_session.commit()
|
||||
|
||||
assert config is not None
|
||||
assert config.bot_token is not None
|
||||
assert config.bot_token.get_value(apply_mask=False) == "test_token_123"
|
||||
assert config.bot_token == "test_token_123"
|
||||
|
||||
# Cleanup
|
||||
delete_discord_bot_config(db_session)
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
|
||||
|
||||
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
|
||||
to the EE implementations, so no patching of the MIT layer is needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from ee.onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from ee.onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
|
||||
# Keys must be exactly 16, 24, or 32 bytes for AES
|
||||
KEY_16 = "a" * 16
|
||||
KEY_16_ALT = "b" * 16
|
||||
KEY_24 = "d" * 24
|
||||
KEY_32 = "c" * 32
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
def test_roundtrip_with_env_key(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("hello world")
|
||||
assert encrypted != b"hello world"
|
||||
assert _decrypt_bytes(encrypted) == "hello world"
|
||||
|
||||
def test_roundtrip_with_explicit_key(self) -> None:
|
||||
encrypted = _encrypt_string("secret data", key=KEY_32)
|
||||
assert encrypted != b"secret data"
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
|
||||
|
||||
def test_roundtrip_no_key(self) -> None:
|
||||
"""Without any key, data is raw-encoded (no encryption)."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
encrypted = _encrypt_string("plain text")
|
||||
assert encrypted == b"plain text"
|
||||
assert _decrypt_bytes(encrypted) == "plain text"
|
||||
|
||||
def test_explicit_key_overrides_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("data", key=KEY_16_ALT)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
|
||||
|
||||
def test_different_encryptions_produce_different_bytes(self) -> None:
|
||||
"""Each encryption uses a random IV, so results differ."""
|
||||
a = _encrypt_string("same", key=KEY_16)
|
||||
b = _encrypt_string("same", key=KEY_16)
|
||||
assert a != b
|
||||
|
||||
def test_roundtrip_empty_string(self) -> None:
|
||||
encrypted = _encrypt_string("", key=KEY_16)
|
||||
assert encrypted != b""
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
|
||||
|
||||
def test_roundtrip_unicode(self) -> None:
|
||||
text = "日本語テスト 🔐 émojis"
|
||||
encrypted = _encrypt_string(text, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == text
|
||||
|
||||
|
||||
class TestDecryptFallbackBehavior:
|
||||
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
|
||||
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
|
||||
raw = "readable text".encode()
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
assert _decrypt_bytes(raw) == "readable text"
|
||||
|
||||
def test_explicit_wrong_key_raises(self) -> None:
|
||||
"""Explicit key path: AES fails → raises, no fallback."""
|
||||
encrypted = _encrypt_string("secret", key=KEY_16)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16_ALT)
|
||||
|
||||
def test_explicit_none_key_with_no_env(self) -> None:
|
||||
"""key=None with empty env → raw decode."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
assert _decrypt_bytes(b"hello", key=None) == "hello"
|
||||
|
||||
def test_explicit_empty_string_key(self) -> None:
|
||||
"""key='' means no encryption."""
|
||||
encrypted = _encrypt_string("test", key="")
|
||||
assert encrypted == b"test"
|
||||
assert _decrypt_bytes(encrypted, key="") == "test"
|
||||
|
||||
|
||||
class TestKeyValidation:
|
||||
def test_key_too_short_raises(self) -> None:
|
||||
with pytest.raises(RuntimeError, match="too short"):
|
||||
_encrypt_string("data", key="short")
|
||||
|
||||
def test_16_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
|
||||
|
||||
def test_24_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_24)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
|
||||
|
||||
def test_32_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_32)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
|
||||
|
||||
def test_long_key_truncated_to_32(self) -> None:
|
||||
"""Keys longer than 32 bytes are truncated to 32."""
|
||||
long_key = "e" * 64
|
||||
encrypted = _encrypt_string("data", key=long_key)
|
||||
assert _decrypt_bytes(encrypted, key=long_key) == "data"
|
||||
|
||||
def test_20_byte_key_trimmed_to_16(self) -> None:
|
||||
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
|
||||
key_20 = "f" * 20
|
||||
encrypted = _encrypt_string("data", key=key_20)
|
||||
assert _decrypt_bytes(encrypted, key=key_20) == "data"
|
||||
|
||||
# Verify it was trimmed to 16 by checking that the first 16 bytes
|
||||
# of the key can also decrypt it
|
||||
key_16_same_prefix = "f" * 16
|
||||
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
|
||||
|
||||
def test_25_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_25 = "g" * 25
|
||||
encrypted = _encrypt_string("data", key=key_25)
|
||||
assert _decrypt_bytes(encrypted, key=key_25) == "data"
|
||||
|
||||
key_24_same_prefix = "g" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
def test_30_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_30 = "h" * 30
|
||||
encrypted = _encrypt_string("data", key=key_30)
|
||||
assert _decrypt_bytes(encrypted, key=key_30) == "data"
|
||||
|
||||
key_24_same_prefix = "h" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
|
||||
class TestWrapperFunctions:
|
||||
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
|
||||
|
||||
With EE mode enabled, the wrappers resolve to EE implementations automatically.
|
||||
"""
|
||||
|
||||
def test_wrapper_passes_key(self) -> None:
|
||||
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
|
||||
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
|
||||
|
||||
def test_wrapper_no_key_uses_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
|
||||
encrypted = encrypt_string_to_bytes("payload")
|
||||
assert decrypt_bytes_to_string(encrypted) == "payload"
|
||||
@@ -1,163 +0,0 @@
|
||||
"""Tests for user file ACL computation, including shared persona access."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.access.access import collect_user_file_access
|
||||
from onyx.access.access import get_access_for_user_files_impl
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
def _make_user(email: str) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.email = email
|
||||
user.id = uuid4()
|
||||
return user
|
||||
|
||||
|
||||
def _make_persona(
|
||||
*,
|
||||
owner: MagicMock | None = None,
|
||||
shared_users: list[MagicMock] | None = None,
|
||||
is_public: bool = False,
|
||||
deleted: bool = False,
|
||||
) -> MagicMock:
|
||||
persona = MagicMock()
|
||||
persona.deleted = deleted
|
||||
persona.is_public = is_public
|
||||
persona.user_id = owner.id if owner else None
|
||||
persona.user = owner
|
||||
persona.users = shared_users or []
|
||||
return persona
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
*,
|
||||
owner: MagicMock,
|
||||
assistants: list[MagicMock] | None = None,
|
||||
) -> MagicMock:
|
||||
uf = MagicMock()
|
||||
uf.id = uuid4()
|
||||
uf.user = owner
|
||||
uf.user_id = owner.id
|
||||
uf.assistants = assistants or []
|
||||
return uf
|
||||
|
||||
|
||||
class TestCollectUserFileAccess:
|
||||
def test_owner_only(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
uf = _make_user_file(owner=owner)
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert emails == {"owner@test.com"}
|
||||
assert is_public is False
|
||||
|
||||
def test_shared_persona_adds_users(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
shared = _make_user("shared@test.com")
|
||||
persona = _make_persona(owner=owner, shared_users=[shared])
|
||||
uf = _make_user_file(owner=owner, assistants=[persona])
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert emails == {"owner@test.com", "shared@test.com"}
|
||||
assert is_public is False
|
||||
|
||||
def test_persona_owner_added(self) -> None:
|
||||
"""Persona owner (different from file owner) gets access too."""
|
||||
file_owner = _make_user("file-owner@test.com")
|
||||
persona_owner = _make_user("persona-owner@test.com")
|
||||
persona = _make_persona(owner=persona_owner)
|
||||
uf = _make_user_file(owner=file_owner, assistants=[persona])
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert "file-owner@test.com" in emails
|
||||
assert "persona-owner@test.com" in emails
|
||||
|
||||
def test_public_persona_makes_file_public(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
persona = _make_persona(owner=owner, is_public=True)
|
||||
uf = _make_user_file(owner=owner, assistants=[persona])
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert is_public is True
|
||||
assert "owner@test.com" in emails
|
||||
|
||||
def test_deleted_persona_ignored(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
shared = _make_user("shared@test.com")
|
||||
persona = _make_persona(owner=owner, shared_users=[shared], deleted=True)
|
||||
uf = _make_user_file(owner=owner, assistants=[persona])
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert emails == {"owner@test.com"}
|
||||
assert is_public is False
|
||||
|
||||
def test_multiple_personas_combine(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
user_a = _make_user("a@test.com")
|
||||
user_b = _make_user("b@test.com")
|
||||
p1 = _make_persona(owner=owner, shared_users=[user_a])
|
||||
p2 = _make_persona(owner=owner, shared_users=[user_b])
|
||||
uf = _make_user_file(owner=owner, assistants=[p1, p2])
|
||||
|
||||
emails, is_public = collect_user_file_access(uf)
|
||||
|
||||
assert emails == {"owner@test.com", "a@test.com", "b@test.com"}
|
||||
|
||||
def test_deduplication(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
shared = _make_user("shared@test.com")
|
||||
p1 = _make_persona(owner=owner, shared_users=[shared])
|
||||
p2 = _make_persona(owner=owner, shared_users=[shared])
|
||||
uf = _make_user_file(owner=owner, assistants=[p1, p2])
|
||||
|
||||
emails, _ = collect_user_file_access(uf)
|
||||
|
||||
assert emails == {"owner@test.com", "shared@test.com"}
|
||||
|
||||
|
||||
class TestGetAccessForUserFiles:
|
||||
def test_shared_user_in_acl(self) -> None:
|
||||
"""Shared persona users should appear in the ACL."""
|
||||
owner = _make_user("owner@test.com")
|
||||
shared = _make_user("shared@test.com")
|
||||
persona = _make_persona(owner=owner, shared_users=[shared])
|
||||
uf = _make_user_file(owner=owner, assistants=[persona])
|
||||
|
||||
db_session = MagicMock()
|
||||
with patch(
|
||||
"onyx.access.access.fetch_user_files_with_access_relationships",
|
||||
return_value=[uf],
|
||||
):
|
||||
result = get_access_for_user_files_impl([str(uf.id)], db_session)
|
||||
|
||||
access = result[str(uf.id)]
|
||||
acl = access.to_acl()
|
||||
assert prefix_user_email("owner@test.com") in acl
|
||||
assert prefix_user_email("shared@test.com") in acl
|
||||
assert access.is_public is False
|
||||
|
||||
def test_public_persona_sets_public_acl(self) -> None:
|
||||
owner = _make_user("owner@test.com")
|
||||
persona = _make_persona(owner=owner, is_public=True)
|
||||
uf = _make_user_file(owner=owner, assistants=[persona])
|
||||
|
||||
db_session = MagicMock()
|
||||
with patch(
|
||||
"onyx.access.access.fetch_user_files_with_access_relationships",
|
||||
return_value=[uf],
|
||||
):
|
||||
result = get_access_for_user_files_impl([str(uf.id)], db_session)
|
||||
|
||||
access = result[str(uf.id)]
|
||||
assert access.is_public is True
|
||||
acl = access.to_acl()
|
||||
assert PUBLIC_DOC_PAT in acl
|
||||
@@ -1,54 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_auth_setting
|
||||
from onyx.configs.constants import AuthType
|
||||
|
||||
|
||||
def test_verify_auth_setting_raises_for_cloud(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Cloud auth type is not valid for self-hosted deployments."""
|
||||
monkeypatch.setenv("AUTH_TYPE", "cloud")
|
||||
|
||||
with pytest.raises(ValueError, match="'cloud' is not a valid auth type"):
|
||||
verify_auth_setting()
|
||||
|
||||
|
||||
def test_verify_auth_setting_warns_for_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Disabled auth type logs a deprecation warning."""
|
||||
monkeypatch.setenv("AUTH_TYPE", "disabled")
|
||||
|
||||
mock_logger = MagicMock()
|
||||
monkeypatch.setattr(users, "logger", mock_logger)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.BASIC)
|
||||
|
||||
verify_auth_setting()
|
||||
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "no longer supported" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth_type",
|
||||
[AuthType.BASIC, AuthType.GOOGLE_OAUTH, AuthType.OIDC, AuthType.SAML],
|
||||
)
|
||||
def test_verify_auth_setting_valid_auth_types(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
auth_type: AuthType,
|
||||
) -> None:
|
||||
"""Valid auth types work without errors or warnings."""
|
||||
monkeypatch.setenv("AUTH_TYPE", auth_type.value)
|
||||
|
||||
mock_logger = MagicMock()
|
||||
monkeypatch.setattr(users, "logger", mock_logger)
|
||||
monkeypatch.setattr(users, "AUTH_TYPE", auth_type)
|
||||
|
||||
verify_auth_setting()
|
||||
|
||||
mock_logger.warning.assert_not_called()
|
||||
mock_logger.notice.assert_called_once_with(f"Using Auth Type: {auth_type.value}")
|
||||
@@ -27,6 +27,7 @@ def _mock_session_returning_none() -> MagicMock:
|
||||
"""Return a mock session whose .get() returns None (file not found)."""
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
return session
|
||||
|
||||
|
||||
@@ -219,10 +220,6 @@ class TestDeleteUserFileImpl:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
|
||||
return_value=[],
|
||||
)
|
||||
class TestProjectSyncUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
@@ -230,7 +227,6 @@ class TestProjectSyncUserFileImpl:
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
_mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
@@ -259,7 +255,6 @@ class TestProjectSyncUserFileImpl:
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
_mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
@@ -282,7 +277,6 @@ class TestProjectSyncUserFileImpl:
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
_mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
@@ -379,13 +379,10 @@ class TestProjectSyncImplNoVectorDb:
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with (
|
||||
patch(
|
||||
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
|
||||
return_value=[uf],
|
||||
),
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
@@ -408,17 +405,14 @@ class TestProjectSyncImplNoVectorDb:
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with patch(
|
||||
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
|
||||
return_value=[uf],
|
||||
):
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert uf.needs_project_sync is False
|
||||
assert uf.needs_persona_sync is False
|
||||
|
||||
@@ -1,630 +0,0 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _make_tool_call_delta(
|
||||
index: int = 0,
|
||||
name: str | None = None,
|
||||
arguments: str | None = None,
|
||||
function_is_none: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
|
||||
delta = MagicMock()
|
||||
delta.index = index
|
||||
if function_is_none:
|
||||
delta.function = None
|
||||
else:
|
||||
delta.function = MagicMock()
|
||||
delta.function.name = name
|
||||
delta.function.arguments = arguments
|
||||
return delta
|
||||
|
||||
|
||||
def _make_placement() -> Placement:
|
||||
return Placement(turn_index=0, tab_index=0)
|
||||
|
||||
|
||||
def _mock_tool_class(emit: bool = True) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
cls.should_emit_argument_deltas.return_value = emit
|
||||
return cls
|
||||
|
||||
|
||||
def _collect(
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
delta: MagicMock,
|
||||
placement: Placement | None = None,
|
||||
parsers: dict[int, Parser] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Run maybe_emit_argument_delta and return the yielded packets."""
|
||||
return list(
|
||||
maybe_emit_argument_delta(
|
||||
tc_map,
|
||||
delta,
|
||||
placement or _make_placement(),
|
||||
parsers if parsers is not None else {},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _stream_fragments(
|
||||
fragments: list[str],
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
placement: Placement | None = None,
|
||||
) -> list[str]:
|
||||
"""Feed fragments into maybe_emit_argument_delta one by one, returning
|
||||
all emitted content values concatenated per-key as a flat list."""
|
||||
pl = placement or _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
emitted: list[str] = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
delta = _make_tool_call_delta(arguments=frag)
|
||||
for packet in maybe_emit_argument_delta(tc_map, delta, pl, parsers=parsers):
|
||||
obj = packet.obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
for value in obj.argument_deltas.values():
|
||||
emitted.append(value)
|
||||
return emitted
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaGuards:
|
||||
"""Tests for conditions that cause no packet to be emitted."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_does_not_opt_in(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Tools that return False from should_emit_argument_deltas emit nothing."""
|
||||
mock_get_tool.return_value = _mock_tool_class(emit=False)
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_class_unknown(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = None
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_no_argument_fragment(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_key_value_not_started(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Key exists in JSON but its string value hasn't begun yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Only the opening brace has arrived — no key to stream yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": "{"}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaBasic:
|
||||
"""Tests for correct packet content and incremental emission."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "', "print(1)", '"}']
|
||||
|
||||
pl = _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
all_packets = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=frag), pl, parsers
|
||||
)
|
||||
all_packets.extend(packets)
|
||||
|
||||
assert len(all_packets) >= 1
|
||||
# Verify packet structure
|
||||
obj = all_packets[0].obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
assert obj.tool_type == "python"
|
||||
# All emitted content should reconstruct the value
|
||||
full_code = ""
|
||||
for p in all_packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
if "code" in p.obj.argument_deltas:
|
||||
full_code += p.obj.argument_deltas["code"]
|
||||
assert full_code == "print(1)"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_only_new_content_on_subsequent_call(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""After a first emission, subsequent calls emit only the diff."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# First fragment opens the string
|
||||
tc_map[0]["arguments"] = '{"code": "abc'
|
||||
packets_1 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments='{"code": "abc'), pl, parsers
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "abc"
|
||||
|
||||
# Second fragment appends more
|
||||
tc_map[0]["arguments"] = '{"code": "abcdef'
|
||||
packets_2 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments="def"), pl, parsers
|
||||
)
|
||||
code_2 = ""
|
||||
for p in packets_2:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_2 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_2 == "def"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""When a second key starts, emissions switch to that key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'", "output": "hello',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "x" in full
|
||||
assert "hello" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains the end of one value and the start of the next key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'y", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "xy" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An empty string value has nothing to emit."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# Opening quote just arrived, value is empty
|
||||
tc_map[0]["arguments"] = '{"code": "'
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments='{"code": "'))
|
||||
# No string content yet, so either no packet or empty deltas
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
assert p.obj.argument_deltas.get("code", "") == ""
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaDecoding:
|
||||
"""Tests verifying that JSON escape sequences are properly decoded."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "line1\\nline2"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "line1\nline2"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\tindented"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "\tindented"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "say \\"hi\\""}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == 'say "hi"'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "path\\\\dir"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "path\\dir"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\u0041"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "A"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_escape_at_end_decoded_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A trailing backslash (incomplete escape) is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\', 'n"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "hello\n"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_unicode_escape_completed_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A partial \\uXX sequence is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\u00', '41"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "helloA"
|
||||
|
||||
|
||||
class TestArgumentDeltaStreamingE2E:
|
||||
"""Simulates realistic sequences of LLM argument deltas to verify
|
||||
the full pipeline produces correct decoded output."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"',
|
||||
"code",
|
||||
'": "',
|
||||
"print(",
|
||||
"'hello')",
|
||||
"\\n",
|
||||
"print(",
|
||||
"'world')",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "print('hello')\nprint('world')"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code with tabs and newlines."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"if True:",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"pass",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "if True:\n\tpass"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An escape sequence split across two fragments (backslash in one,
|
||||
'n' in the next) should still decode correctly."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "hello',
|
||||
"\\",
|
||||
"n",
|
||||
'world"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "hello\nworld"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams a multi-line function with multiple escape sequences."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"def foo():",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"x = 1",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"return x",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "def foo():\n\tx = 1\n\treturn x"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code first, then a second key (language) — both decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = 1",
|
||||
'", "language": "',
|
||||
"python",
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
# Should have emissions for both keys
|
||||
full = "".join(emitted)
|
||||
assert "x = 1" in full
|
||||
assert "python" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
|
||||
The escaped quotes inside the *outer* JSON value should prevent the
|
||||
inner `"key":` from being mistaken for a top-level JSON key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
|
||||
# The inner quotes are escaped as \" in the JSON value.
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = {",
|
||||
'\\"key\\"',
|
||||
": ",
|
||||
'\\"val\\"',
|
||||
"}",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'x = {"key": "val"}'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Colons inside the string value should not confuse key detection."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"url = ",
|
||||
'\\"https://example.com\\"',
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'url = "https://example.com"'
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaEdgeCases:
|
||||
"""Edge cases not covered by the standard test classes."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Some delta chunks have function=None (e.g. role-only deltas)."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
|
||||
assert _collect(tc_map, delta) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Two tool calls streaming at different indices in parallel."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""},
|
||||
1: {"id": "tc_2", "name": "python", "arguments": ""},
|
||||
}
|
||||
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# Feed full JSON to index 0
|
||||
tc_map[0]["arguments"] = '{"code": "aaa"}'
|
||||
packets_0 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_0 = ""
|
||||
for p in packets_0:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_0 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_0 == "aaa"
|
||||
|
||||
# Feed full JSON to index 1
|
||||
tc_map[1]["arguments"] = '{"code": "bbb"}'
|
||||
packets_1 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "bbb"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains four complete key-value pairs."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
full = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
tc_map[0]["arguments"] = full
|
||||
parsers: dict[int, Parser] = {}
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=full), parsers=parsers
|
||||
)
|
||||
|
||||
# Collect all argument deltas across packets
|
||||
all_deltas: dict[str, str] = {}
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
for k, v in p.obj.argument_deltas.items():
|
||||
all_deltas[k] = all_deltas.get(k, "") + v
|
||||
|
||||
assert all_deltas == {
|
||||
"a": "one",
|
||||
"b": "two",
|
||||
"c": "three",
|
||||
"d": "four",
|
||||
}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_on_second_arg_after_first_complete(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""First argument is fully complete; delta only adds to the second."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
|
||||
fragments = [
|
||||
'{"code": "print(1)", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "print(1)" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Non-string values (numbers, booleans, null) are skipped — they are
|
||||
available in the final tool-call kickoff packet. String arguments
|
||||
following them are still emitted."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"timeout": 30, "code": "hello"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert full == "hello"
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Unit tests for SharepointConnector._fetch_site_pages error handling.
|
||||
|
||||
Covers 404 handling (classic sites / no modern pages) and 400
|
||||
canvasLayout fallback (corrupt pages causing $expand=canvasLayout to
|
||||
fail on the LIST endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_INVALID_REQUEST_CODE
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
|
||||
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
|
||||
PAGES_COLLECTION = f"https://graph.microsoft.com/v1.0/sites/{FAKE_SITE_ID}/pages"
|
||||
SITE_PAGES_BASE = f"{PAGES_COLLECTION}/microsoft.graph.sitePage"
|
||||
|
||||
|
||||
def _site_descriptor() -> SiteDescriptor:
|
||||
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
|
||||
|
||||
|
||||
def _make_http_error(
|
||||
status_code: int,
|
||||
error_code: str = "itemNotFound",
|
||||
message: str = "Item not found",
|
||||
) -> HTTPError:
|
||||
body = {"error": {"code": error_code, "message": message}}
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response._content = json.dumps(body).encode()
|
||||
response.headers["Content-Type"] = "application/json"
|
||||
return HTTPError(response=response)
|
||||
|
||||
|
||||
def _setup_connector(
|
||||
monkeypatch: pytest.MonkeyPatch, # noqa: ARG001
|
||||
) -> SharepointConnector:
|
||||
"""Create a connector with the graph client and site resolution mocked."""
|
||||
connector = SharepointConnector(sites=[SITE_URL])
|
||||
connector.graph_api_base = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
mock_sites = type(
|
||||
"FakeSites",
|
||||
(),
|
||||
{
|
||||
"get_by_url": staticmethod(
|
||||
lambda url: type( # noqa: ARG005
|
||||
"Q",
|
||||
(),
|
||||
{
|
||||
"execute_query": lambda self: None, # noqa: ARG005
|
||||
"id": FAKE_SITE_ID,
|
||||
},
|
||||
)()
|
||||
),
|
||||
},
|
||||
)()
|
||||
connector._graph_client = type("FakeGraphClient", (), {"sites": mock_sites})()
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
def _patch_graph_api_get_json(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake_fn: Any,
|
||||
) -> None:
|
||||
monkeypatch.setattr(SharepointConnector, "_graph_api_get_json", fake_fn)
|
||||
|
||||
|
||||
class TestFetchSitePages404:
|
||||
def test_404_yields_no_pages(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 from the Pages API should result in zero yielded pages."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert pages == []
|
||||
|
||||
def test_404_does_not_raise(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 must not propagate as an exception."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
for _ in connector._fetch_site_pages(_site_descriptor()):
|
||||
pass
|
||||
|
||||
def test_non_404_http_error_still_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Non-404 HTTP errors (e.g. 403) must still propagate."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(403)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
def test_successful_fetch_yields_pages(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the API succeeds, pages should be yielded normally."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
fake_page = {
|
||||
"id": "page-1",
|
||||
"title": "Hello World",
|
||||
"webUrl": f"{SITE_URL}/SitePages/Hello.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
return {"value": [fake_page]}
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
def test_404_on_second_page_stops_pagination(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the first API page succeeds but a nextLink returns 404,
|
||||
already-yielded pages are kept and iteration stops cleanly."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
first_page = {
|
||||
"id": "page-1",
|
||||
"title": "First",
|
||||
"webUrl": f"{SITE_URL}/SitePages/First.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return {
|
||||
"value": [first_page],
|
||||
"@odata.nextLink": "https://graph.microsoft.com/next",
|
||||
}
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
|
||||
class TestFetchSitePages400Fallback:
|
||||
"""When $expand=canvasLayout on the LIST endpoint returns 400
|
||||
invalidRequest, _fetch_site_pages should fall back to listing
|
||||
without expansion, then expanding each page individually."""
|
||||
|
||||
GOOD_PAGE: dict[str, Any] = {
|
||||
"id": "good-1",
|
||||
"name": "Good.aspx",
|
||||
"title": "Good Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
BAD_PAGE: dict[str, Any] = {
|
||||
"id": "bad-1",
|
||||
"name": "Bad.aspx",
|
||||
"title": "Bad Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
GOOD_PAGE_EXPANDED: dict[str, Any] = {
|
||||
**GOOD_PAGE,
|
||||
"canvasLayout": {"horizontalSections": []},
|
||||
}
|
||||
|
||||
def test_fallback_expands_good_pages_individually(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""On 400 from the LIST expand, the connector should list without
|
||||
expand, then GET each page individually with $expand=canvasLayout."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
good_page = self.GOOD_PAGE
|
||||
bad_page = self.BAD_PAGE
|
||||
good_page_expanded = self.GOOD_PAGE_EXPANDED
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str,
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == SITE_PAGES_BASE and params is None:
|
||||
return {"value": [good_page, bad_page]}
|
||||
expand_params = {"$expand": "canvasLayout"}
|
||||
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return good_page_expanded
|
||||
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
raise AssertionError(f"Unexpected call: {url} {params}")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
assert len(pages) == 2
|
||||
assert pages[0].get("canvasLayout") is not None
|
||||
assert pages[1].get("canvasLayout") is None
|
||||
assert pages[1]["id"] == "bad-1"
|
||||
|
||||
def test_mid_pagination_400_does_not_duplicate(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the first paginated batch succeeds but a later nextLink
|
||||
returns 400, pages from the first batch must not be re-yielded
|
||||
by the fallback."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
good_page = self.GOOD_PAGE
|
||||
good_page_expanded = self.GOOD_PAGE_EXPANDED
|
||||
bad_page = self.BAD_PAGE
|
||||
second_page = {
|
||||
"id": "page-2",
|
||||
"name": "Second.aspx",
|
||||
"title": "Second Page",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
next_link = "https://graph.microsoft.com/v1.0/next-page-link"
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str,
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
|
||||
return {
|
||||
"value": [good_page],
|
||||
"@odata.nextLink": next_link,
|
||||
}
|
||||
if url == next_link:
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == SITE_PAGES_BASE and params is None:
|
||||
return {"value": [good_page, bad_page, second_page]}
|
||||
expand_params = {"$expand": "canvasLayout"}
|
||||
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return good_page_expanded
|
||||
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
raise _make_http_error(
|
||||
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
|
||||
)
|
||||
if url == f"{PAGES_COLLECTION}/page-2/microsoft.graph.sitePage":
|
||||
assert params == expand_params, f"Expected $expand params, got {params}"
|
||||
return {**second_page, "canvasLayout": {"horizontalSections": []}}
|
||||
raise AssertionError(f"Unexpected call: {url} {params}")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
ids = [p["id"] for p in pages]
|
||||
assert ids == ["good-1", "bad-1", "page-2"]
|
||||
|
||||
def test_non_invalid_request_400_still_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A 400 with a different error code (not invalidRequest) should
|
||||
propagate, not trigger the fallback."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(400, "badRequest", "Something else went wrong")
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
list(connector._fetch_site_pages(_site_descriptor()))
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
models = [
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
@@ -1,196 +0,0 @@
|
||||
import io
|
||||
|
||||
import openpyxl
|
||||
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
|
||||
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
|
||||
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
|
||||
wb = openpyxl.Workbook()
|
||||
if wb.active is not None:
|
||||
wb.remove(wb.active)
|
||||
for sheet_name, rows in sheets.items():
|
||||
ws = wb.create_sheet(title=sheet_name)
|
||||
for row in rows:
|
||||
ws.append(row)
|
||||
buf = io.BytesIO()
|
||||
wb.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestXlsxToText:
|
||||
def test_single_sheet_basic(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["Name", "Age"],
|
||||
["Alice", "30"],
|
||||
["Bob", "25"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "Name" in lines[0]
|
||||
assert "Age" in lines[0]
|
||||
assert "Alice" in lines[1]
|
||||
assert "30" in lines[1]
|
||||
assert "Bob" in lines[2]
|
||||
|
||||
def test_multiple_sheets_separated(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [["a", "b"]],
|
||||
"Sheet2": [["c", "d"]],
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# TEXT_SECTION_SEPARATOR is "\n\n"
|
||||
assert "\n\n" in result
|
||||
parts = result.split("\n\n")
|
||||
assert any("a" in p for p in parts)
|
||||
assert any("c" in p for p in parts)
|
||||
|
||||
def test_empty_cells(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "b"],
|
||||
["", "c", ""],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
|
||||
def test_commas_in_cells_are_quoted(self) -> None:
|
||||
"""Cells containing commas should be quoted in CSV output."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["hello, world", "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert '"hello, world"' in result
|
||||
|
||||
def test_empty_workbook(self) -> None:
|
||||
xlsx = _make_xlsx({"Sheet1": []})
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert result.strip() == ""
|
||||
|
||||
def test_long_empty_row_run_capped(self) -> None:
|
||||
"""Runs of >2 empty rows should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["header"],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
["data"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
|
||||
assert len(lines) == 4
|
||||
assert "header" in lines[0]
|
||||
assert "data" in lines[-1]
|
||||
|
||||
def test_long_empty_col_run_capped(self) -> None:
|
||||
"""Runs of >2 empty columns should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "", "", "b"],
|
||||
["c", "", "", "", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
# Each row should have 4 fields (a + 2 empty + b), not 5
|
||||
# csv format: a,,,b (3 commas = 4 fields)
|
||||
first_line = lines[0].strip()
|
||||
# Count commas to verify column reduction
|
||||
assert first_line.count(",") == 3
|
||||
|
||||
def test_short_empty_runs_kept(self) -> None:
|
||||
"""Runs of <=2 empty rows/cols should be preserved."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "b"],
|
||||
["", ""],
|
||||
["", ""],
|
||||
["c", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# All 4 rows preserved (2 empty rows <= threshold)
|
||||
assert len(lines) == 4
|
||||
|
||||
def test_bad_zip_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="test.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_bad_zip_tilde_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_large_sparse_sheet(self) -> None:
|
||||
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
|
||||
rows: list[list[str]] = [["row1_data"]]
|
||||
rows.extend([[""] for _ in range(10)])
|
||||
rows.append(["row2_data"])
|
||||
xlsx = _make_xlsx({"Sheet1": rows})
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
|
||||
assert len(lines) == 4
|
||||
assert "row1_data" in lines[0]
|
||||
assert "row2_data" in lines[-1]
|
||||
|
||||
def test_quotes_in_cells(self) -> None:
|
||||
"""Cells containing quotes should be properly escaped."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
['say "hello"', "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# csv.writer escapes quotes by doubling them
|
||||
assert '""hello""' in result
|
||||
|
||||
def test_each_row_is_separate_line(self) -> None:
|
||||
"""Each row should produce its own line (regression for writerow vs writerows)."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["r1c1", "r1c2"],
|
||||
["r2c1", "r2c2"],
|
||||
["r3c1", "r3c2"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "r1c1" in lines[0] and "r1c2" in lines[0]
|
||||
assert "r2c1" in lines[1] and "r2c2" in lines[1]
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
@@ -26,14 +26,6 @@ class TestIsTrueOpenAIModel:
|
||||
"""Test that real OpenAI GPT-4o-mini model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True
|
||||
|
||||
def test_real_openai_o1_preview(self) -> None:
|
||||
"""Test that real OpenAI o1-preview reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True
|
||||
|
||||
def test_real_openai_o1_mini(self) -> None:
|
||||
"""Test that real OpenAI o1-mini reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True
|
||||
|
||||
def test_openai_with_provider_prefix(self) -> None:
|
||||
"""Test that OpenAI model with provider prefix is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False
|
||||
|
||||
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for Slack channel reference resolution and tag filtering
|
||||
in handle_regular_answer.py."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_client_with_channels(
|
||||
channel_map: dict[str, str],
|
||||
) -> MagicMock:
|
||||
"""Return a mock WebClient where conversations_info resolves IDs to names."""
|
||||
client = MagicMock()
|
||||
|
||||
def _conversations_info(channel: str) -> MagicMock:
|
||||
if channel in channel_map:
|
||||
resp = MagicMock()
|
||||
resp.validate = MagicMock()
|
||||
resp.__getitem__ = lambda _self, key: {
|
||||
"channel": {
|
||||
"name": channel_map[channel],
|
||||
"is_im": False,
|
||||
"is_mpim": False,
|
||||
}
|
||||
}[key]
|
||||
return resp
|
||||
raise SlackApiError("channel_not_found", response=MagicMock())
|
||||
|
||||
client.conversations_info = _conversations_info
|
||||
return client
|
||||
|
||||
|
||||
def _mock_logger() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SLACK_CHANNEL_REF_PATTERN regex tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackChannelRefPattern:
|
||||
def test_matches_bare_channel_id(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
|
||||
assert matches == [("C097NBWMY8Y", "")]
|
||||
|
||||
def test_matches_channel_id_with_name(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
|
||||
assert matches == [("C097NBWMY8Y", "eng-infra")]
|
||||
|
||||
def test_matches_multiple_channels(self) -> None:
|
||||
msg = "compare <#C111AAA> and <#C222BBB|general>"
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
|
||||
assert len(matches) == 2
|
||||
assert ("C111AAA", "") in matches
|
||||
assert ("C222BBB", "general") in matches
|
||||
|
||||
def test_no_match_on_plain_text(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
|
||||
assert matches == []
|
||||
|
||||
def test_no_match_on_user_mention(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
|
||||
assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_channel_references tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveChannelReferences:
|
||||
def test_resolves_bare_channel_id_via_api(self) -> None:
|
||||
client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summary of <#C097NBWMY8Y> this week",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summary of #eng-infra this week"
|
||||
assert len(tags) == 1
|
||||
assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")
|
||||
|
||||
def test_uses_name_from_pipe_format_without_api_call(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#C097NBWMY8Y|eng-infra> for updates",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "check #eng-infra for updates"
|
||||
assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
|
||||
# Should NOT have called the API since name was in the markup
|
||||
client.conversations_info.assert_not_called()
|
||||
|
||||
def test_multiple_channels(self) -> None:
|
||||
client = _mock_client_with_channels(
|
||||
{
|
||||
"C111AAA": "eng-infra",
|
||||
"C222BBB": "eng-general",
|
||||
}
|
||||
)
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#eng-general" in message
|
||||
assert "<#" not in message
|
||||
assert len(tags) == 2
|
||||
tag_values = {t.tag_value for t in tags}
|
||||
assert tag_values == {"eng-infra", "eng-general"}
|
||||
|
||||
def test_no_channel_references_returns_unchanged(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="just a normal message with no channels",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "just a normal message with no channels"
|
||||
assert tags == []
|
||||
|
||||
def test_api_failure_skips_channel_gracefully(self) -> None:
|
||||
# Client that fails for all channel lookups
|
||||
client = _mock_client_with_channels({})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Message should remain unchanged for the failed channel
|
||||
assert "<#CBADID123>" in message
|
||||
assert tags == []
|
||||
logger.warning.assert_called_once()
|
||||
|
||||
def test_partial_failure_resolves_what_it_can(self) -> None:
|
||||
# Only one of two channels resolves
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "<#CBADID123>" in message # failed one stays raw
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_duplicate_channel_produces_single_tag(self) -> None:
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summarize <#C111AAA> and compare with <#C111AAA>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summarize #eng-infra and compare with #eng-infra"
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_mixed_pipe_and_bare_formats(self) -> None:
|
||||
client = _mock_client_with_channels({"C222BBB": "random"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="see <#C111AAA|eng-infra> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#random" in message
|
||||
assert len(tags) == 2
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Unit tests for _get_user_access_info helper function.
|
||||
|
||||
These tests mock all database operations and don't require a real database.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.hierarchy.api import _get_user_access_info
|
||||
|
||||
|
||||
def test_get_user_access_info_returns_email_and_groups() -> None:
|
||||
"""_get_user_access_info returns the user's email and external group IDs."""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "test@example.com"
|
||||
mock_db_session = MagicMock(spec=Session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.features.hierarchy.api.get_user_external_group_ids",
|
||||
return_value=["group1", "group2"],
|
||||
):
|
||||
email, groups = _get_user_access_info(mock_user, mock_db_session)
|
||||
|
||||
assert email == "test@example.com"
|
||||
assert groups == ["group1", "group2"]
|
||||
|
||||
|
||||
def test_get_user_access_info_with_no_groups() -> None:
|
||||
"""User with no external groups returns empty list."""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "solo@example.com"
|
||||
mock_db_session = MagicMock(spec=Session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.features.hierarchy.api.get_user_external_group_ids",
|
||||
return_value=[],
|
||||
):
|
||||
email, groups = _get_user_access_info(mock_user, mock_db_session)
|
||||
|
||||
assert email == "solo@example.com"
|
||||
assert groups == []
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Tests for LLM model fetch endpoints.
|
||||
|
||||
These tests verify the full request/response flow for fetching models
|
||||
from dynamic providers (Ollama, OpenRouter), including the
|
||||
from dynamic providers (Ollama, OpenRouter, Litellm), including the
|
||||
sync-to-DB behavior when provider_name is specified.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetLitellmAvailableModels:
|
||||
"""Tests for the Litellm proxy model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response(self) -> dict:
|
||||
"""Mock response from Litellm /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1700000001,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "gemini-pro",
|
||||
"object": "model",
|
||||
"created": 1700000002,
|
||||
"owned_by": "google",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted model list."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
|
||||
|
||||
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that provider_name and model_name are correctly extracted."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
gpt = next(r for r in results if r.model_name == "gpt-4o")
|
||||
assert gpt.provider_name == "openai"
|
||||
|
||||
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
|
||||
assert claude.provider_name == "anthropic"
|
||||
|
||||
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that results are alphabetically sorted by model_name."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
model_names = [r.model_name for r in results]
|
||||
assert model_names == sorted(model_names, key=str.lower)
|
||||
|
||||
def test_empty_data_raises_onyx_error(self) -> None:
|
||||
"""Test that empty model list raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No models found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_missing_data_key_raises_onyx_error(self) -> None:
|
||||
"""Test that response without 'data' key raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_skips_unparseable_entries(self) -> None:
|
||||
"""Test that malformed model entries are skipped without failing."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_with_bad_entry = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
# Missing required fields
|
||||
{"bad_field": "bad_value"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_with_bad_entry
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].model_name == "gpt-4o"
|
||||
|
||||
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
|
||||
"""Test that OnyxError is raised when all entries fail to parse."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_all_bad = {
|
||||
"data": [
|
||||
{"bad_field": "bad_value"},
|
||||
{"another_bad": 123},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_all_bad
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No compatible models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_api_base_trailing_slash_handled(self) -> None:
|
||||
"""Test that trailing slashes in api_base are handled correctly."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_litellm_response = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000/",
|
||||
api_key="test-key",
|
||||
)
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should call /v1/models without double slashes
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://localhost:4000/v1/models"
|
||||
|
||||
def test_connection_failure_raises_onyx_error(self) -> None:
|
||||
"""Test that connection failures are wrapped in OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
"""Test that a 401 response raises OnyxError with authentication message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="bad-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Authentication failed"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_404_raises_not_found_error(self) -> None:
|
||||
"""Test that a 404 response raises OnyxError with endpoint not found message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
raise OSError("tell not supported")
|
||||
|
||||
def seek(self, *_args: object, **_kwargs: object) -> int:
|
||||
raise OSError("seek not supported")
|
||||
|
||||
|
||||
def _make_upload(filename: str, size: int, content: bytes | None = None) -> UploadFile:
|
||||
payload = content if content is not None else (b"x" * size)
|
||||
return UploadFile(filename=filename, file=BytesIO(payload), size=size)
|
||||
|
||||
|
||||
def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
upload = UploadFile(filename="example.txt", file=BytesIO(b"abcdef"), size=None)
|
||||
upload.file.seek(2)
|
||||
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size == 6
|
||||
assert upload.file.tell() == 2
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_logs_warning_when_stream_size_unavailable(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = UploadFile(filename="non_seekable.txt", file=_NonSeekableFile(), size=None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size is None
|
||||
assert "Could not determine upload size via stream seek" in caplog.text
|
||||
assert "non_seekable.txt" in caplog.text
|
||||
|
||||
|
||||
def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = _make_upload("size_unknown.txt", size=1)
|
||||
monkeypatch.setattr(utils, "get_upload_size_bytes", lambda _upload: None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
is_too_large = utils.is_upload_too_large(upload, max_bytes=100)
|
||||
|
||||
assert is_too_large is False
|
||||
assert "Could not determine upload size; skipping size-limit check" in caplog.text
|
||||
assert "size_unknown.txt" in caplog.text
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
assert [file.filename for file in result.acceptable] == ["small.png"]
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "large.png"
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
@@ -1,32 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
def __init__(self) -> None:
|
||||
self._vals: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._vals.get(key)
|
||||
|
||||
def set(self, key: str, value: str, ex: int | None = None) -> None: # noqa: ARG002
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
@@ -147,18 +147,15 @@ class TestSensitiveValueString:
|
||||
)
|
||||
assert sensitive1 != sensitive2
|
||||
|
||||
def test_equality_with_non_sensitive_returns_not_equal(self) -> None:
|
||||
"""Test that comparing with non-SensitiveValue is always not-equal.
|
||||
|
||||
Returns NotImplemented so Python falls back to identity comparison.
|
||||
This is required for compatibility with SQLAlchemy's attribute tracking.
|
||||
"""
|
||||
def test_equality_with_non_sensitive_raises(self) -> None:
|
||||
"""Test that comparing with non-SensitiveValue raises error."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert not (sensitive == "secret")
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
_ = sensitive == "secret"
|
||||
|
||||
|
||||
class TestSensitiveValueJson:
|
||||
|
||||
@@ -158,14 +158,14 @@ python ./scripts/dev_run_background_jobs.py
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='basic'
|
||||
$env:AUTH_TYPE='disabled'
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
env_file:
|
||||
- .env_eval
|
||||
environment:
|
||||
- AUTH_TYPE=basic
|
||||
- AUTH_TYPE=disabled
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
@@ -58,7 +58,7 @@ services:
|
||||
env_file:
|
||||
- .env_eval
|
||||
environment:
|
||||
- AUTH_TYPE=basic
|
||||
- AUTH_TYPE=disabled
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
|
||||
@@ -20,12 +20,8 @@ IMAGE_TAG=latest
|
||||
|
||||
## Auth Settings
|
||||
### https://docs.onyx.app/deployment/authentication
|
||||
AUTH_TYPE=basic
|
||||
AUTH_TYPE=disabled
|
||||
# SESSION_EXPIRE_TIME_SECONDS=
|
||||
### Recommended for basic auth - used for signing password reset and verification tokens
|
||||
### If using install.sh, this will be auto-generated
|
||||
### If setting manually, run: openssl rand -hex 32
|
||||
USER_AUTH_SECRET=""
|
||||
### Recommend to set this for security
|
||||
# ENCRYPTION_KEY_SECRET=
|
||||
### Optional
|
||||
|
||||
@@ -654,20 +654,17 @@ else
|
||||
sed -i.bak "s/^IMAGE_TAG=.*/IMAGE_TAG=$VERSION/" "$ENV_FILE"
|
||||
print_success "IMAGE_TAG set to $VERSION"
|
||||
|
||||
# Configure basic authentication (default)
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Basic authentication enabled in configuration"
|
||||
|
||||
# Check if openssl is available
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
print_error "openssl is required to generate secure secrets but was not found."
|
||||
exit 1
|
||||
# Configure authentication settings based on selection
|
||||
if [ "$AUTH_SCHEMA" = "disabled" ]; then
|
||||
# Disable authentication in .env file
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=disabled/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Authentication disabled in configuration"
|
||||
else
|
||||
# Enable basic authentication
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Basic authentication enabled in configuration"
|
||||
fi
|
||||
|
||||
# Generate a secure USER_AUTH_SECRET
|
||||
USER_AUTH_SECRET=$(openssl rand -hex 32)
|
||||
sed -i.bak "s/^USER_AUTH_SECRET=.*/USER_AUTH_SECRET=\"$USER_AUTH_SECRET\"/" "$ENV_FILE" 2>/dev/null || true
|
||||
|
||||
# Configure Craft based on flag or if using a craft-* image tag
|
||||
# By default, env.template has Craft commented out (disabled)
|
||||
if [ "$INCLUDE_CRAFT" = true ] || [[ "$VERSION" == craft-* ]]; then
|
||||
|
||||
@@ -1266,5 +1266,3 @@ configMap:
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
@@ -143,7 +143,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.6.3",
|
||||
"onyx-devtools==0.7.0",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
@@ -153,7 +153,7 @@ dev = [
|
||||
"pytest-repeat==0.9.4",
|
||||
"pytest-xdist==3.8.0",
|
||||
"pytest==8.3.5",
|
||||
"release-tag==0.4.3",
|
||||
"release-tag==0.5.2",
|
||||
"reorder-python-imports-black==3.14.0",
|
||||
"ruff==0.12.0",
|
||||
"types-beautifulsoup4==4.12.0.3",
|
||||
|
||||
35
uv.lock
generated
35
uv.lock
generated
@@ -4443,7 +4443,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.3" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.0" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4485,7 +4485,7 @@ requires-dist = [
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
@@ -4548,20 +4548,19 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.6.3"
|
||||
version = "0.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/84/e2/e7619722c3ccd18eb38100f776fb3dd6b4ae0fbbee09fca5af7c69a279b5/onyx_devtools-0.6.3-py3-none-any.whl", hash = "sha256:d3a5422945d9da12cafc185f64b39f6e727ee4cc92b37427deb7a38f9aad4966", size = 3945381, upload-time = "2026-03-05T20:39:25.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/09/513d2dabedc1e54ad4376830fc9b34a3d9c164bdbcdedfcdbb8b8154dc5a/onyx_devtools-0.6.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:efe300e9f3a2e7ae75f88a4f9e0a5c4c471478296cb1615b6a1f03d247582e13", size = 3978761, upload-time = "2026-03-05T20:39:28.822Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/41/e757602a0de032d74ed01c7ee57f30e57728fb9cd4f922f50d2affda3889/onyx_devtools-0.6.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:594066eed3f917cfab5a8c7eac3d4a210df30259f2049f664787749709345e19", size = 3665378, upload-time = "2026-03-05T20:44:22.696Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/1c/c93b65d0b32e202596a2647922a75c7011cb982f899ddfcfd171f792c58f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:384ef66030b55c0fd68b3898782b5b4b868ff3de119569dfc8544e2ce534b98a", size = 3540890, upload-time = "2026-03-05T20:39:28.886Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/33/760eb656013f7f0cdff24570480d3dc4e52bbd8e6147ea1e8cf6fad7554f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:82e218f3a49f64910c2c4c34d5dc12d1ea1520a27e0b0f6e4c0949ff9abaf0e1", size = 3945396, upload-time = "2026-03-05T20:39:34.323Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/eb/f54b3675c464df8a51194ff75afc97c2417659e3a209dc46948b47c28860/onyx_devtools-0.6.3-py3-none-win_amd64.whl", hash = "sha256:8af614ae7229290ef2417cb85270184a1e826ed9a3a34658da93851edb36df57", size = 4045936, upload-time = "2026-03-05T20:39:28.375Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/b8/5bee38e748f3d4b8ec935766224db1bbc1214c91092e5822c080fccd9130/onyx_devtools-0.6.3-py3-none-win_arm64.whl", hash = "sha256:717589db4b42528d33ae96f8006ee6aad3555034dcfee724705b6576be6a6ec4", size = 3608268, upload-time = "2026-03-05T20:39:28.731Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/9e/6957b11555da57d9e97092f4cd8ac09a86666264b0c9491838f4b27db5dc/onyx_devtools-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ad962a168d46ea11dcde9fa3b37e4f12ec520b4a4cb4d49d8732de110d46c4b6", size = 3998057, upload-time = "2026-03-12T03:09:11.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/90/c72f3d06ba677012d77c77de36195b6a32a15c755c79ba0282be74e3c366/onyx_devtools-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e46d252e2b048ff053b03519c3a875998780738d7c334eaa1c9a32ff445e3e1a", size = 3687753, upload-time = "2026-03-12T03:09:11.742Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/42/4e9fe36eccf9f76d67ba8f4ff6539196a09cd60351fb63f5865e1544cbfa/onyx_devtools-0.7.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:f280bc9320e1cc310e7d753a371009bfaab02cc0e0cfd78559663b15655b5a50", size = 3560144, upload-time = "2026-03-12T03:12:24.02Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/40/36dc12d99760b358c7f39b27361cb18fa9681ffe194107f982d0e1a74016/onyx_devtools-0.7.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:e31df751c7540ae7e70a7fe8e1153c79c31c2254af6aa4c72c0dd54fa381d2ab", size = 3964387, upload-time = "2026-03-12T03:09:11.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/18/74744230c3820a5a7687335507ca5f1dbebab2c5325805041c1cd5703e6a/onyx_devtools-0.7.0-py3-none-win_amd64.whl", hash = "sha256:541bfd347c2d5b11e7f63ab5001d2594df91d215ad9d07b1562f5e715700f7e6", size = 4068030, upload-time = "2026-03-12T03:09:12.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/78/1320436607d3ffcb321ba7b064556c020ea15843a7e7d903fbb7529a71f5/onyx_devtools-0.7.0-py3-none-win_arm64.whl", hash = "sha256:83016330a9d39712431916cc25b2fb2cfcaa0112a55cc4f919d545da3a8974f9", size = 3626409, upload-time = "2026-03-12T03:09:10.222Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6338,16 +6337,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "release-tag"
|
||||
version = "0.4.3"
|
||||
version = "0.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
3
web/.gitignore
vendored
3
web/.gitignore
vendored
@@ -47,6 +47,3 @@ next-env.d.ts
|
||||
# generated clients ... in particular, the API to the Onyx backend itself!
|
||||
/src/lib/generated
|
||||
.jest-cache
|
||||
|
||||
# storybook
|
||||
storybook-static
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import { Meta } from "@storybook/blocks";
|
||||
|
||||
<Meta title="Getting Started" />
|
||||
|
||||
# Onyx Storybook
|
||||
|
||||
A living catalog for browsing, testing, and documenting Onyx UI components in isolation.
|
||||
|
||||
---
|
||||
|
||||
## What is this?
|
||||
|
||||
This Storybook contains interactive examples of every reusable UI component in the Onyx frontend. Each component has a dedicated page with:
|
||||
|
||||
- **Live demos** you can interact with directly
|
||||
- **Controls** to tweak props and see how the component responds
|
||||
- **Auto-generated docs** showing the full props API
|
||||
- **Dark mode toggle** in the toolbar to preview both themes
|
||||
|
||||
---
|
||||
|
||||
## Navigating Storybook
|
||||
|
||||
### Sidebar
|
||||
|
||||
The left sidebar organizes components by layer:
|
||||
|
||||
- **opal/core** — Low-level primitives (`Interactive`, `Hoverable`)
|
||||
- **opal/components** — Design system atoms (`Button`, `OpenButton`, `Tag`)
|
||||
- **Layouts** — Structural layouts (`Content`, `ContentAction`, `IllustrationContent`)
|
||||
- **refresh-components** — App-level components (inputs, modals, tables, text, etc.)
|
||||
|
||||
Click any component to see its stories. Click **Docs** to see the auto-generated props table.
|
||||
|
||||
### Controls panel
|
||||
|
||||
At the bottom of each story, the **Controls** panel lets you change props in real time. Toggle booleans, pick from enums, type in strings — the preview updates instantly.
|
||||
|
||||
### Theme toggle
|
||||
|
||||
Use the paint roller icon in the top toolbar to switch between **light** and **dark** mode. All components use CSS variables that automatically adapt.
|
||||
|
||||
---
|
||||
|
||||
## Running locally
|
||||
|
||||
```bash
|
||||
cd web
|
||||
npm run storybook # dev server on :6006
|
||||
npm run storybook:build # static build to storybook-static/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding a new story
|
||||
|
||||
Stories are **co-located** next to their component:
|
||||
|
||||
```
|
||||
lib/opal/src/components/buttons/Button/
|
||||
├── components.tsx ← the component
|
||||
├── Button.stories.tsx ← the story
|
||||
├── styles.css
|
||||
└── README.md
|
||||
```
|
||||
|
||||
### Minimal template
|
||||
|
||||
```tsx
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { MyComponent } from "./MyComponent";
|
||||
|
||||
const meta: Meta<typeof MyComponent> = {
|
||||
title: "opal/components/MyComponent", // sidebar path
|
||||
component: MyComponent,
|
||||
tags: ["autodocs"], // auto-generate docs page
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof MyComponent>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
title: "Hello",
|
||||
},
|
||||
};
|
||||
|
||||
export const WithCustomLayout: Story = {
|
||||
render: () => (
|
||||
<div className="flex gap-2">
|
||||
<MyComponent title="One" />
|
||||
<MyComponent title="Two" />
|
||||
</div>
|
||||
),
|
||||
};
|
||||
```
|
||||
|
||||
### Conventions
|
||||
|
||||
- **Title format:** `opal/core/Name`, `opal/components/Name`, `Layouts/Name`, or `refresh-components/Name`
|
||||
- **Tags:** Add `tags: ["autodocs"]` to auto-generate a docs page from props
|
||||
- **Decorators:** If your component needs `TooltipPrimitive.Provider` (anything with tooltips), add it as a decorator
|
||||
- **Layout:** Use `parameters: { layout: "fullscreen" }` for modals/popovers that use portals
|
||||
|
||||
---
|
||||
|
||||
## Deployment
|
||||
|
||||
Production builds deploy to [onyx-storybook.vercel.app](https://onyx-storybook.vercel.app) automatically when PRs touching component files merge to `main`.
|
||||
|
||||
Monitored paths:
|
||||
|
||||
- `web/lib/opal/**`
|
||||
- `web/src/refresh-components/**`
|
||||
- `web/.storybook/**`
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user