Compare commits

..

31 Commits

Author SHA1 Message Date
Evan Lohn
923e0691aa fix: sharepoint pages 400 list expand (#9321) 2026-03-13 11:43:08 -07:00
Evan Lohn
b232e2a771 fix: skip classic site pages (#9318) 2026-03-12 22:08:49 -07:00
Evan Lohn
c3ebfeda2f chore: sharepoint error logs (#9309) 2026-03-12 21:37:50 -07:00
github-actions[bot]
6a28dfedb1 fix(fe): prevent clicking InputSelect from selecting text (#9292) to release v3.0 (#9306)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-12 09:36:25 -07:00
github-actions[bot]
a123ec083d chore(devtools): upgrade ods: 0.6.3->0.7.0 (#9297) to release v3.0 (#9298)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-11 20:37:22 -07:00
Nikolas Garza
f448f1274d fix(slackbot): resolve channel references and filter search by channel tags (#9256) to release v3.0 (#9294) 2026-03-11 20:23:28 -07:00
Jamison Lahman
d12f8b94aa fix(fe): InputComboBox resets filter value on open (#9287) to release v3.0 (#9291) 2026-03-11 18:36:36 -07:00
Bo-Onyx
355fe2ff2c fix(api memory): replace glibc with jemalloc for memory allocating (#9196) to release v3.0 (#9282)
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
2026-03-11 14:58:43 -07:00
Nikolas Garza
8ec5423a0c fix(tests): remove deprecated o1-preview and o1-mini model tests (#9280) 2026-03-11 14:37:03 -07:00
Justin Tahara
79b615db46 feat(litellm): Adding FE Provider workflow (#9264) 2026-03-11 11:56:04 -07:00
Wenxi
98756bccd4 fix: discord connector async resource cleanup (#9203) 2026-03-11 11:53:57 -07:00
Wenxi
418f84ccdf fix: don't fetch mcp tools when no llms are configured (#9173) 2026-03-11 11:53:57 -07:00
Wenxi
d37756a884 fix(mcp): use CE-compatible chat endpoint for search_indexed_documents (#9193)
Co-authored-by: Fizza-Mukhtar <fizzamukhtar01@gmail.com>
2026-03-11 11:53:57 -07:00
Wenxi
9cdc92441b fix: fallback doc access when drive item is externally owned (#9053) 2026-03-11 11:53:57 -07:00
Wenxi
b8ed30644a fix: move available context tokens to useChatController and remove arbitrary 50% cap (#9174) 2026-03-11 11:53:57 -07:00
Danelegend
d7d19e5a28 feat(llm-provider): fetch litellm models (#8418) 2026-03-11 10:55:51 -07:00
github-actions[bot]
948650829d chore(release): upgrade release-tag (#9257) to release v3.0 (#9261)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-10 18:16:55 -07:00
Evan Lohn
b6e689be0f fix: update jira group sync endpoint (#9241) 2026-03-10 17:13:10 -07:00
github-actions[bot]
85877408c8 fix(fe): increase responsive breakpoint for centering modals (#9250) to release v3.0 (#9251)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-10 14:59:46 -07:00
github-actions[bot]
c00df75c79 fix(fe): correctly parse comma literals in CSVs (#9245) to release v3.0 (#9249)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-10 14:13:40 -07:00
github-actions[bot]
6352c9a09e fix(fe): make CSV inline display responsive (#9242) to release v3.0 (#9246)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-10 13:19:54 -07:00
github-actions[bot]
3065f70d7d fix: Prevent the removal and hiding of default model (#9131) to release v3.0 (#9225)
Co-authored-by: Danelegend <43459662+Danelegend@users.noreply.github.com>
2026-03-10 10:54:21 -07:00
Jamison Lahman
4befbc49dc chore(release): run playwright on release pushes (#9233) to release v3.0 (#9238) 2026-03-10 10:29:22 -07:00
Jamison Lahman
ae9679e8c4 fix(safari): Search results dont shrink (#9126) to release v3.0 (#9210) 2026-03-10 09:11:32 -07:00
SubashMohan
ea0ddee5c8 feat(custom-tools): enhance custom tool error handling and timeline UI (#9189) 2026-03-10 17:04:15 +05:30
Nikolas Garza
2826405dd2 fix: use detail instead of message in OnyxError response shape (#9214) to release v3.0 (#9220) 2026-03-09 19:15:29 -07:00
github-actions[bot]
8485bf4368 fix(fe): fix chat content padding (#9216) to release v3.0 (#9218)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-09 17:57:37 -07:00
github-actions[bot]
7bb52b0839 fix(code-interpreter): set default CODE_INTERPRETER_BASE_URL w/ docke… (#9215) to release v3.0 (#9219)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-09 17:57:21 -07:00
github-actions[bot]
85a54c01f1 feat(opensearch): Enable by default (#9211) to release v3.0 (#9217)
Co-authored-by: acaprau <48705707+acaprau@users.noreply.github.com>
2026-03-09 17:35:44 -07:00
github-actions[bot]
e4577bd564 fix(fe): move app padding inside overflow container (#9206) to release v3.0 (#9207)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-09 13:47:36 -07:00
Nikolas Garza
f150a7b940 fix(fe): fix broken slack bot admin pages (#9168) 2026-03-09 13:01:58 -07:00
266 changed files with 3107 additions and 15184 deletions

View File

@@ -151,7 +151,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.

View File

@@ -70,7 +70,7 @@ jobs:
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -471,7 +471,7 @@ jobs:
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -710,7 +710,7 @@ jobs:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
with:
pattern: screenshot-diff-summary-*
path: summaries/

View File

@@ -28,7 +28,7 @@ jobs:
with:
python-version: "3.11"
- name: Setup Terraform
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
with: # zizmor: ignore[cache-poisoning]

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -1,69 +0,0 @@
name: Storybook Deploy
env:
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
VERCEL_CLI: vercel@50.14.1
VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
concurrency:
group: storybook-deploy-production
cancel-in-progress: true
on:
workflow_dispatch:
push:
branches:
- main
paths:
- "web/lib/opal/**"
- "web/src/refresh-components/**"
- "web/.storybook/**"
- "web/package.json"
- "web/package-lock.json"
permissions:
contents: read
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Install dependencies
working-directory: web
run: npm ci
- name: Build Storybook
working-directory: web
run: npm run storybook:build
- name: Deploy to Vercel (Production)
working-directory: web
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
notify-slack-on-failure:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
with:
persist-credentials: false
sparse-checkout: .github/actions/slack-notify
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• Deploy-Storybook"
title: "🚨 Storybook Deploy Failed"

View File

@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -7,9 +7,6 @@
AUTH_TYPE=basic
# Recommended for basic auth - used for signing password reset and verification tokens
# Generate a secure value with: openssl rand -hex 32
USER_AUTH_SECRET=""
DEV_MODE=true

View File

@@ -143,7 +143,6 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
# Run Craft template setup at build time when ENABLE_CRAFT=true

View File

@@ -1,51 +0,0 @@
"""add hierarchy_node_by_connector_credential_pair table
Revision ID: b5c4d7e8f9a1
Revises: a3b8d9e2f1c4
Create Date: 2026-03-04
"""
import sqlalchemy as sa
from alembic import op
revision = "b5c4d7e8f9a1"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"hierarchy_node_by_connector_credential_pair",
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
sa.Column("connector_id", sa.Integer(), nullable=False),
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["hierarchy_node_id"],
["hierarchy_node.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["connector_id", "credential_id"],
[
"connector_credential_pair.connector_id",
"connector_credential_pair.credential_id",
],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
)
op.create_index(
"ix_hierarchy_node_cc_pair_connector_credential",
"hierarchy_node_by_connector_credential_pair",
["connector_id", "credential_id"],
)
def downgrade() -> None:
op.drop_index(
"ix_hierarchy_node_cc_pair_connector_credential",
table_name="hierarchy_node_by_connector_credential_pair",
)
op.drop_table("hierarchy_node_by_connector_credential_pair")

View File

@@ -11,6 +11,7 @@ from sqlalchemy import text
from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
@@ -21,52 +22,59 @@ depends_on = None
def upgrade() -> None:
# Enable pg_trgm extension if not already enabled
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
if MULTI_TENANT:
# Create the read-only db user if it does not already exist.
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
# Enable pg_trgm extension if not already enabled
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
op.execute(
text(
f"""
DO $$
BEGIN
-- Check if the read-only user already exists
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- Create the read-only user with the specified password
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
-- First revoke all privileges to ensure a clean slate
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Grant only the CONNECT privilege to allow the user to connect to the database
-- but not perform any operations without additional specific grants
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
# the user is created in the standard migration.
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
op.execute(
text(
f"""
DO $$
BEGIN
-- Check if the read-only user already exists
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- Create the read-only user with the specified password
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
-- First revoke all privileges to ensure a clean slate
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Grant only the CONNECT privilege to allow the user to connect to the database
-- but not perform any operations without additional specific grants
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)
)
def downgrade() -> None:
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- First revoke all privileges from the database
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Then revoke all privileges from the public schema
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
-- Then drop the user
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
END IF;
END
$$;
"""
if MULTI_TENANT:
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
# the user is dropped in the alembic_tenants migration.
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- First revoke all privileges from the database
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Then revoke all privileges from the public schema
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
-- Then drop the user
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)
)
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))

View File

@@ -9,15 +9,12 @@ from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from onyx.access.access import collect_user_file_access
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.logger import setup_logger
@@ -119,68 +116,6 @@ def _get_access_for_documents(
return access_map
def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
"""Extract user-group names from the already-loaded Persona.groups
relationships on a UserFile (skipping deleted personas)."""
groups: set[str] = set()
for persona in user_file.assistants:
if persona.deleted:
continue
for group in persona.groups:
groups.add(group.name)
return groups
def get_access_for_user_files_impl(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
"""EE version: extends the MIT user file ACL with user group names
from personas shared via user groups.
Uses a single DB query (via fetch_user_files_with_access_relationships)
that eagerly loads both the MIT-needed and EE-needed relationships.
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
DO NOT REMOVE."""
user_files = fetch_user_files_with_access_relationships(
user_file_ids, db_session, eager_load_groups=True
)
return build_access_for_user_files_impl(user_files)
def build_access_for_user_files_impl(
user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
"""EE version: works on pre-loaded UserFile objects.
Expects Persona.groups to be eagerly loaded.
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
DO NOT REMOVE."""
result: dict[str, DocumentAccess] = {}
for user_file in user_files:
if user_file.user is None:
result[str(user_file.id)] = DocumentAccess.build(
user_emails=[],
user_groups=[],
is_public=True,
external_user_emails=[],
external_user_group_ids=[],
)
continue
emails, is_public = collect_user_file_access(user_file)
group_names = _collect_user_file_group_names(user_file)
result[str(user_file.id)] = DocumentAccess.build(
user_emails=list(emails),
user_groups=list(group_names),
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
)
return result
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@@ -1,4 +1,3 @@
import os
from datetime import datetime
import jwt
@@ -21,13 +20,7 @@ logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
# All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode
def _build_hierarchy_access_filter(
user_email: str,
user_email: str | None,
external_group_ids: list[str],
) -> ColumnElement[bool]:
"""Build SQLAlchemy filter for hierarchy node access.
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
def _get_accessible_hierarchy_nodes_for_source(
db_session: Session,
source: DocumentSource,
user_email: str,
user_email: str | None,
external_group_ids: list[str],
) -> list[HierarchyNode]:
"""

View File

@@ -7,7 +7,6 @@ from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.db.persona import mark_persona_user_files_for_sync
from onyx.server.features.persona.models import PersonaSharedNotificationData
@@ -27,9 +26,7 @@ def update_persona_access(
NOTE: Callers are responsible for committing."""
needs_sync = False
if is_public is not None:
needs_sync = True
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
@@ -38,7 +35,6 @@ def update_persona_access(
# and a non-empty list means "replace with these shares".
if user_ids is not None:
needs_sync = True
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -58,7 +54,6 @@ def update_persona_access(
)
if group_ids is not None:
needs_sync = True
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -68,7 +63,3 @@ def update_persona_access(
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)
# When sharing changes, user file ACLs need to be updated in the vector DB
if needs_sync:
mark_persona_user_files_for_sync(persona_id, db_session)

View File

@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any
from jira import JIRA
from jira.exceptions import JIRAError
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.jira.utils import build_jira_client
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
_GROUP_MEMBER_PAGE_SIZE = 50
def _get_jira_group_members_email(
# The GET /group/member endpoint was introduced in Jira 6.0.
# Jira versions older than 6.0 do not have group management REST APIs at all.
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
def _fetch_group_member_page(
jira_client: JIRA,
group_name: str,
) -> list[str]:
"""Get all member emails for a Jira group.
start_at: int,
) -> dict[str, Any]:
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
Filters out app accounts (bots, integrations) and only returns real user emails.
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
directly via the library's internal _get_json helper, following the same pattern
as enhanced_search_ids / bulk_fetch_issues in connector.py.
There is an open PR to the library to switch to this endpoint since last year:
https://github.com/pycontribs/jira/pull/2356
so once it is merged and released, we can switch to using the library function.
"""
emails: list[str] = []
try:
# group_members returns an OrderedDict of account_id -> member_info
members = jira_client.group_members(group=group_name)
if not members:
logger.warning(f"No members found for group {group_name}")
return emails
for account_id, member_info in members.items():
# member_info is a dict with keys like 'fullname', 'email', 'active'
email = member_info.get("email")
# Skip "hidden" emails - these are typically app accounts
if email and email != "hidden":
emails.append(email)
else:
# For cloud, we might need to fetch user details separately
try:
user = jira_client.user(id=account_id)
# Skip app accounts (bots, integrations, etc.)
if hasattr(user, "accountType") and user.accountType == "app":
logger.info(
f"Skipping app account {account_id} for group {group_name}"
)
continue
if hasattr(user, "emailAddress") and user.emailAddress:
emails.append(user.emailAddress)
else:
logger.warning(f"User {account_id} has no email address")
except Exception as e:
logger.warning(
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
return emails
return jira_client._get_json(
"group/member",
params={
"groupname": group_name,
"includeInactiveUsers": "false",
"startAt": start_at,
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
},
)
except JIRAError as e:
if e.status_code == 404:
raise RuntimeError(
f"GET /group/member returned 404 for group '{group_name}'. "
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
f"If you are running a self-hosted Jira instance, please upgrade "
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
) from e
raise
def _build_group_member_email_map(
def _get_group_member_emails(
jira_client: JIRA,
) -> dict[str, set[str]]:
"""Build a map of group names to member emails."""
group_member_emails: dict[str, set[str]] = {}
group_name: str,
) -> set[str]:
"""Get all member emails for a single Jira group.
try:
# Get all groups from Jira - returns a list of group name strings
group_names = jira_client.groups()
Uses the non-deprecated GET /group/member endpoint which returns full user
objects including accountType, so we can filter out app/customer accounts
without making separate user() calls.
"""
emails: set[str] = set()
start_at = 0
if not group_names:
logger.warning("No groups found in Jira")
return group_member_emails
while True:
try:
page = _fetch_group_member_page(jira_client, group_name, start_at)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
raise
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
members: list[dict[str, Any]] = page.get("values", [])
for member in members:
account_type = member.get("accountType")
# On Jira DC < 9.0, accountType is absent; include those users.
# On Cloud / DC 9.0+, filter to real user accounts only.
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
continue
member_emails = _get_jira_group_members_email(
jira_client=jira_client,
group_name=group_name,
)
if member_emails:
group_member_emails[group_name] = set(member_emails)
logger.debug(
f"Found {len(member_emails)} members for group {group_name}"
)
email = member.get("emailAddress")
if email:
emails.add(email)
else:
logger.debug(f"No members found for group {group_name}")
logger.warning(
f"Atlassian user {member.get('accountId', 'unknown')} "
f"in group {group_name} has no visible email address"
)
except Exception as e:
logger.error(f"Error building group member email map: {e}")
if page.get("isLast", True) or not members:
break
start_at += len(members)
return group_member_emails
return emails
def jira_group_sync(
tenant_id: str, # noqa: ARG001
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""
Sync Jira groups and their members.
"""Sync Jira groups and their members, yielding one group at a time.
This function fetches all groups from Jira and yields ExternalUserGroup
objects containing the group ID and member emails.
Streams group-by-group rather than accumulating all groups in memory.
"""
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
scoped_token = cc_pair.connector.connector_specific_config.get(
@@ -130,12 +127,26 @@ def jira_group_sync(
scoped_token=scoped_token,
)
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
if not group_member_email_map:
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
group_names = jira_client.groups()
if not group_names:
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
for group_id, group_member_emails in group_member_email_map.items():
yield ExternalUserGroup(
id=group_id,
user_emails=list(group_member_emails),
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
continue
member_emails = _get_group_member_emails(
jira_client=jira_client,
group_name=group_name,
)
if not member_emails:
logger.debug(f"No members found for group {group_name}")
continue
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
yield ExternalUserGroup(
id=group_name,
user_emails=list(member_emails),
)

View File

@@ -14,91 +14,67 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@lru_cache(maxsize=2)
@lru_cache(maxsize=1)
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
# Trim to the largest valid AES key size that fits
valid_lengths = [32, 24, 16]
for size in valid_lengths:
if key_length >= size:
return encoded_key[:size]
raise AssertionError("unreachable")
return encoded_key
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
def _encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
return input_str.encode()
trimmed = _get_trimmed_key(effective_key)
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = urandom(16)
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(input_str.encode()) + padder.finalize()
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted_data
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
def _decrypt_bytes(input_bytes: bytes) -> str:
if not ENCRYPTION_KEY_SECRET:
return input_bytes.decode()
trimmed = _get_trimmed_key(effective_key)
try:
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
cipher = Cipher(
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
)
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
return decrypted_data.decode()
except (ValueError, UnicodeDecodeError):
if key is not None:
# Explicit key was provided — don't fall back silently
raise
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
# Does NOT handle data encrypted with a different key — that
# ciphertext is not valid UTF-8 and will raise below.
logger.warning(
"AES decryption failed — falling back to raw decode. "
"Run the re-encrypt secrets script to rotate to the current key."
)
try:
return input_bytes.decode()
except UnicodeDecodeError:
raise ValueError(
"Data is not valid UTF-8 — likely encrypted with a different key. "
"Run the re-encrypt secrets script to rotate to the current key."
) from None
return decrypted_data.decode()
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
def encrypt_string_to_bytes(input_str: str) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(input_str, key=key)
return versioned_encryption_fn(input_str)
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(input_bytes, key=key)
return versioned_decryption_fn(input_bytes)
def test_encryption() -> None:

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.access.models import DocumentAccess
@@ -11,7 +12,6 @@ from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -132,61 +132,19 @@ def get_access_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
versioned_fn = fetch_versioned_implementation(
"onyx.access.access", "get_access_for_user_files_impl"
user_files = (
db_session.query(UserFile)
.options(joinedload(UserFile.user)) # Eager load the user relationship
.filter(UserFile.id.in_(user_file_ids))
.all()
)
return versioned_fn(user_file_ids, db_session)
def get_access_for_user_files_impl(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
return build_access_for_user_files_impl(user_files)
def build_access_for_user_files(
user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
"""Compute access from pre-loaded UserFile objects (with relationships).
Callers must ensure UserFile.user, Persona.users, and Persona.user are
eagerly loaded (and Persona.groups for the EE path)."""
versioned_fn = fetch_versioned_implementation(
"onyx.access.access", "build_access_for_user_files_impl"
)
return versioned_fn(user_files)
def build_access_for_user_files_impl(
user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
result: dict[str, DocumentAccess] = {}
for user_file in user_files:
emails, is_public = collect_user_file_access(user_file)
result[str(user_file.id)] = DocumentAccess.build(
user_emails=list(emails),
return {
str(user_file.id): DocumentAccess.build(
user_emails=[user_file.user.email] if user_file.user else [],
user_groups=[],
is_public=is_public,
is_public=True if user_file.user is None else False,
external_user_emails=[],
external_user_group_ids=[],
)
return result
def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
"""Collect all user emails that should have access to this user file.
Includes the owner plus any users who have access via shared personas.
Returns (emails, is_public)."""
emails: set[str] = {user_file.user.email}
is_public = False
for persona in user_file.assistants:
if persona.deleted:
continue
if persona.is_public:
is_public = True
if persona.user_id is not None and persona.user:
emails.add(persona.user.email)
for shared_user in persona.users:
emails.add(shared_user.email)
return emails, is_public
for user_file in user_files
}

View File

@@ -1,5 +1,4 @@
import json
import os
import random
import secrets
import string
@@ -146,22 +145,10 @@ def is_user_admin(user: User) -> bool:
def verify_auth_setting() -> None:
"""Log warnings for AUTH_TYPE issues.
This only runs on app startup not during migrations/scripts.
"""
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "cloud":
if AUTH_TYPE == AuthType.CLOUD:
raise ValueError(
"'cloud' is not a valid auth type for self-hosted deployments."
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
)
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -115,6 +115,8 @@ def _extract_from_batch(
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
@@ -123,7 +125,8 @@ def _extract_from_batch(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
ids[item.id] = item.parent_hierarchy_raw_node_id
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
@@ -189,7 +192,9 @@ def extract_ids_from_runnable_connector(
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
all_raw_id_to_parent.update(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
all_hierarchy_nodes.extend(batch_nodes)
if callback:

View File

@@ -40,7 +40,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
@@ -290,14 +289,6 @@ def _run_hierarchy_extraction(
is_connector_public=is_connector_public,
)
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=[n.id for n in upserted_nodes],
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
commit=True,
)
# Cache in Redis for fast ancestor resolution
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes

View File

@@ -48,15 +48,10 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -65,7 +60,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
@@ -585,12 +579,11 @@ def connector_pruning_generator_task(
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes: list[DBHierarchyNode] = []
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
@@ -599,14 +592,6 @@ def connector_pruning_generator_task(
is_connector_public=is_connector_public,
)
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=[n.id for n in upserted_nodes],
connector_id=connector_id,
credential_id=credential_id,
commit=True,
)
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in upserted_nodes
@@ -622,6 +607,7 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
@@ -678,43 +664,6 @@ def connector_pruning_generator_task(
)
redis_connector.prune.generator_complete = tasks_generated
# --- Hierarchy node pruning ---
live_node_ids = {n.id for n in upserted_nodes}
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
live_hierarchy_node_ids=live_node_ids,
commit=True,
)
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
db_session=db_session,
source=source,
commit=True,
)
reparented_nodes = reparent_orphaned_hierarchy_nodes(
db_session=db_session,
source=source,
commit=True,
)
if deleted_raw_ids:
evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids)
if reparented_nodes:
reparented_cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in reparented_nodes
]
cache_hierarchy_nodes_batch(
redis_client, source, reparented_cache_entries
)
if stale_removed or deleted_raw_ids or reparented_nodes:
task_logger.info(
f"Hierarchy node pruning: cc_pair={cc_pair_id} "
f"stale_entries_removed={stale_removed} "
f"nodes_deleted={len(deleted_raw_ids)} "
f"nodes_reparented={len(reparented_nodes)}"
)
except Exception as e:
task_logger.exception(
f"Pruning exceptioned: cc_pair={cc_pair_id} "

View File

@@ -12,9 +12,9 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
@@ -43,9 +43,7 @@ from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
@@ -56,7 +54,6 @@ from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAd
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.variable_functionality import global_version
def _as_uuid(value: str | UUID) -> UUID:
@@ -794,12 +791,11 @@ def project_sync_user_file_impl(
try:
with get_session_with_current_tenant() as db_session:
user_files = fetch_user_files_with_access_relationships(
[user_file_id],
db_session,
eager_load_groups=global_version.is_ee_version(),
)
user_file = user_files[0] if user_files else None
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"project_sync_user_file_impl - User file not found id={user_file_id}"
@@ -827,21 +823,12 @@ def project_sync_user_file_impl(
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
file_id_str = str(user_file.id)
access_map = build_access_for_user_files([user_file])
access = access_map.get(file_id_str)
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=file_id_str,
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=(
VespaDocumentFields(access=access)
if access is not None
else None
),
fields=None,
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,

View File

@@ -45,7 +45,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import ProcessingMode
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
@@ -588,14 +587,6 @@ def connector_document_extraction(
is_connector_public=is_connector_public,
)
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=[n.id for n in upserted_nodes],
connector_id=db_connector.id,
credential_id=db_credential.id,
commit=True,
)
# Cache in Redis for fast ancestor resolution during doc processing
redis_client = get_redis_client(tenant_id=tenant_id)
cache_entries = [

View File

@@ -15,7 +15,6 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
@@ -55,7 +54,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.jsonriver import Parser
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
from onyx.utils.text_processing import find_all_json_objects
@@ -1011,7 +1009,6 @@ def run_llm_step_pkt_generator(
)
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
arg_parsers: dict[int, Parser] = {}
reasoning_start = False
answer_start = False
accumulated_reasoning = ""
@@ -1218,14 +1215,7 @@ def run_llm_step_pkt_generator(
yield from _close_reasoning_if_active()
for tool_call_delta in delta.tool_calls:
# maybe_emit depends and update being called first and attaching the delta
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
yield from maybe_emit_argument_delta(
tool_calls_in_progress=id_to_tool_call_map,
tool_call_delta=tool_call_delta,
placement=_current_placement(),
parsers=arg_parsers,
)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()

View File

@@ -1,77 +0,0 @@
from collections.abc import Generator
from collections.abc import Mapping
from typing import Any
from typing import Type
from onyx.llm.model_response import ChatCompletionDeltaToolCall
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
from onyx.tools.interface import Tool
from onyx.utils.jsonriver import Parser
def _get_tool_class(
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
tool_call_delta: ChatCompletionDeltaToolCall,
) -> Type[Tool] | None:
"""Look up the Tool subclass for a streaming tool call delta."""
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
if not tool_name:
return None
return TOOL_NAME_TO_CLASS.get(tool_name)
def maybe_emit_argument_delta(
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
tool_call_delta: ChatCompletionDeltaToolCall,
placement: Placement,
parsers: dict[int, Parser],
) -> Generator[Packet, None, None]:
"""Emit decoded tool-call argument deltas to the frontend.
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
the JSON argument string and extract only the newly-appended content
for each string-valued argument.
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
are skipped — they are available in the final tool-call kickoff packet.
``parsers`` is a mutable dict keyed by tool-call index. A new
``Parser`` is created automatically for each new index.
"""
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
if not tool_cls or not tool_cls.should_emit_argument_deltas():
return
fn = tool_call_delta.function
delta_fragment = fn.arguments if fn else None
if not delta_fragment:
return
idx = tool_call_delta.index
if idx not in parsers:
parsers[idx] = Parser()
parser = parsers[idx]
deltas = parser.feed(delta_fragment)
argument_deltas: dict[str, str] = {}
for delta in deltas:
if isinstance(delta, dict):
for key, value in delta.items():
if isinstance(value, str):
argument_deltas[key] = argument_deltas.get(key, "") + value
if not argument_deltas:
return
tc_data = tool_calls_in_progress[tool_call_delta.index]
yield Packet(
placement=placement,
obj=ToolCallArgumentDelta(
tool_type=tc_data.get("name", ""),
argument_deltas=argument_deltas,
),
)

View File

@@ -68,10 +68,6 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -96,12 +92,19 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
#####
# Auth Configs
#####
# Silently default to basic - warnings/errors logged in verify_auth_setting()
# which only runs on app startup, not during migrations/scripts
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Defaulting to 'basic'. Please update your configuration. "
"Your existing data will be migrated automatically."
)
_auth_type_str = AuthType.BASIC.value
try:
AUTH_TYPE = AuthType(_auth_type_str)
else:
except ValueError:
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
AUTH_TYPE = AuthType.BASIC
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
@@ -204,12 +207,6 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET:
logger.warning(
"USER_AUTH_SECRET is not set. This is required for secure password reset "
"and email verification tokens. Please set USER_AUTH_SECRET in production."
)
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
# By default, this is set to match the Redis expiry time for consistency.
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(

View File

@@ -44,7 +44,6 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.db.credentials import update_credential_json
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import unwrap_str
from onyx.server.documents.models import CredentialBase
from onyx.server.documents.models import GoogleAppCredentials
from onyx.server.documents.models import GoogleServiceAccountKey
@@ -90,7 +89,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
def verify_csrf(credential_id: int, state: str) -> None:
csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id))))
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
@@ -179,9 +178,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id),
{"value": params.get("state", [None])[0]},
encrypt=True,
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
)
return str(auth_url)

View File

@@ -33,6 +33,7 @@ from office365.runtime.queries.client_query import ClientQuery # type: ignore[i
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from pydantic import BaseModel
from pydantic import Field
from requests.exceptions import HTTPError
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
@@ -268,6 +269,32 @@ class SizeCapExceeded(Exception):
"""Exception raised when the size cap is exceeded."""
def _log_and_raise_for_status(response: requests.Response) -> None:
"""Log the response text and raise for status."""
try:
response.raise_for_status()
except Exception:
logger.error(f"HTTP request failed: {response.text}")
raise
GRAPH_INVALID_REQUEST_CODE = "invalidRequest"
def _is_graph_invalid_request(response: requests.Response) -> bool:
"""Return True if the response body is the generic Graph API
``{"error": {"code": "invalidRequest", "message": "Invalid request"}}``
shape. This particular error has no actionable inner error code and is
returned by the site-pages endpoint when a page has a corrupt canvas layout
(e.g. duplicate web-part IDs — see SharePoint/sp-dev-docs#8822)."""
try:
body = response.json()
except Exception:
return False
error = body.get("error", {})
return error.get("code") == GRAPH_INVALID_REQUEST_CODE
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
"""Load certificate from .pfx file for MSAL authentication"""
try:
@@ -344,7 +371,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
"""Determine remote size using HEAD or a range GET probe. Returns None if unknown."""
try:
head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
head_resp.raise_for_status()
_log_and_raise_for_status(head_resp)
cl = head_resp.headers.get("Content-Length")
if cl and cl.isdigit():
return int(cl)
@@ -359,7 +386,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
timeout=timeout,
stream=True,
) as range_resp:
range_resp.raise_for_status()
_log_and_raise_for_status(range_resp)
cr = range_resp.headers.get("Content-Range") # e.g., "bytes 0-0/12345"
if cr and "/" in cr:
total = cr.split("/")[-1]
@@ -384,7 +411,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
- Returns the full bytes if the content fits within `cap`.
"""
with requests.get(url, stream=True, timeout=timeout) as resp:
resp.raise_for_status()
_log_and_raise_for_status(resp)
# If the server provides Content-Length, prefer an early decision.
cl_header = resp.headers.get("Content-Length")
@@ -428,7 +455,7 @@ def _download_via_graph_api(
with requests.get(
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
) as resp:
resp.raise_for_status()
_log_and_raise_for_status(resp)
buf = io.BytesIO()
for chunk in resp.iter_content(64 * 1024):
if not chunk:
@@ -1238,26 +1265,135 @@ class SharepointConnector(
site.execute_query()
site_id = site.id
page_url: str | None = (
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
site_pages_base = (
f"{self.graph_api_base}/sites/{site_id}/pages/microsoft.graph.sitePage"
)
page_url: str | None = site_pages_base
params: dict[str, str] | None = {"$expand": "canvasLayout"}
total_yielded = 0
yielded_ids: set[str] = set()
while page_url:
data = self._graph_api_get_json(page_url, params)
try:
data = self._graph_api_get_json(page_url, params)
except HTTPError as e:
if e.response is not None and e.response.status_code == 404:
logger.warning(f"Site page not found: {page_url}")
break
if (
e.response is not None
and e.response.status_code == 400
and _is_graph_invalid_request(e.response)
):
logger.warning(
f"$expand=canvasLayout on the LIST endpoint returned 400 "
f"for site {site_descriptor.url}. Falling back to "
f"per-page expansion."
)
yield from self._fetch_site_pages_individually(
site_pages_base, start, end, skip_ids=yielded_ids
)
return
raise
params = None # nextLink already embeds query params
for page in data.get("value", []):
if not _site_page_in_time_window(page, start, end):
continue
total_yielded += 1
page_id = page.get("id")
if page_id:
yielded_ids.add(page_id)
yield page
page_url = data.get("@odata.nextLink")
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
def _fetch_site_pages_individually(
self,
site_pages_base: str,
start: datetime | None = None,
end: datetime | None = None,
skip_ids: set[str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Fallback for _fetch_site_pages: list pages without $expand, then
expand canvasLayout on each page individually.
The Graph API's LIST endpoint can return 400 when $expand=canvasLayout
is used and *any* page in the site has a corrupt canvas layout (e.g.
duplicate web part IDs — see SharePoint/sp-dev-docs#8822). Since the
LIST expansion is all-or-nothing, a single bad page poisons the entire
response. This method works around it by fetching metadata first, then
expanding each page individually so only the broken page loses its
canvas content.
``skip_ids`` contains page IDs already yielded by the caller before the
fallback was triggered, preventing duplicates.
"""
page_url: str | None = site_pages_base
total_yielded = 0
_skip_ids = skip_ids or set()
while page_url:
try:
data = self._graph_api_get_json(page_url)
except HTTPError as e:
if e.response is not None and e.response.status_code == 404:
break
raise
for page in data.get("value", []):
if not _site_page_in_time_window(page, start, end):
continue
page_id = page.get("id")
if page_id and page_id in _skip_ids:
continue
if not page_id:
total_yielded += 1
yield page
continue
expanded = self._try_expand_single_page(site_pages_base, page_id, page)
total_yielded += 1
yield expanded
page_url = data.get("@odata.nextLink")
logger.debug(
f"Yielded {total_yielded} site pages (per-page expansion fallback)"
)
def _try_expand_single_page(
self,
site_pages_base: str,
page_id: str,
fallback_page: dict[str, Any],
) -> dict[str, Any]:
"""Try to GET a single page with $expand=canvasLayout. On 400, return
the metadata-only fallback so the page is still indexed (without canvas
content)."""
pages_collection = site_pages_base.removesuffix("/microsoft.graph.sitePage")
single_url = f"{pages_collection}/{page_id}/microsoft.graph.sitePage"
try:
return self._graph_api_get_json(single_url, {"$expand": "canvasLayout"})
except HTTPError as e:
if (
e.response is not None
and e.response.status_code == 400
and _is_graph_invalid_request(e.response)
):
page_name = fallback_page.get("name", page_id)
logger.warning(
f"$expand=canvasLayout failed for page '{page_name}' "
f"({page_id}). Indexing metadata only."
)
return fallback_page
raise
def _acquire_token(self) -> dict[str, Any]:
"""
Acquire token via MSAL
@@ -1309,7 +1445,7 @@ class SharepointConnector(
access_token = self._get_graph_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
continue
response.raise_for_status()
_log_and_raise_for_status(response)
return response.json()
except (requests.ConnectionError, requests.Timeout):
if attempt < GRAPH_API_MAX_RETRIES:

View File

@@ -2,10 +2,7 @@
from collections import defaultdict
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
@@ -13,7 +10,6 @@ from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.db.enums import HierarchyNodeType
from onyx.db.models import Document
from onyx.db.models import HierarchyNode
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -462,7 +458,7 @@ def get_all_hierarchy_nodes_for_source(
def _get_accessible_hierarchy_nodes_for_source(
db_session: Session,
source: DocumentSource,
user_email: str, # noqa: ARG001
user_email: str | None, # noqa: ARG001
external_group_ids: list[str], # noqa: ARG001
) -> list[HierarchyNode]:
"""
@@ -489,7 +485,7 @@ def _get_accessible_hierarchy_nodes_for_source(
def get_accessible_hierarchy_nodes_for_source(
db_session: Session,
source: DocumentSource,
user_email: str,
user_email: str | None,
external_group_ids: list[str],
) -> list[HierarchyNode]:
"""
@@ -624,154 +620,3 @@ def update_hierarchy_node_permissions(
db_session.flush()
return True
def upsert_hierarchy_node_cc_pair_entries(
db_session: Session,
hierarchy_node_ids: list[int],
connector_id: int,
credential_id: int,
commit: bool = True,
) -> None:
"""Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.
This records that the given cc_pair "owns" these hierarchy nodes. Used by
indexing, pruning, and hierarchy-fetching paths.
"""
if not hierarchy_node_ids:
return
_M = HierarchyNodeByConnectorCredentialPair
stmt = pg_insert(_M).values(
[
{
_M.hierarchy_node_id: node_id,
_M.connector_id: connector_id,
_M.credential_id: credential_id,
}
for node_id in hierarchy_node_ids
]
)
stmt = stmt.on_conflict_do_nothing()
db_session.execute(stmt)
if commit:
db_session.commit()
else:
db_session.flush()
def remove_stale_hierarchy_node_cc_pair_entries(
db_session: Session,
connector_id: int,
credential_id: int,
live_hierarchy_node_ids: set[int],
commit: bool = True,
) -> int:
"""Delete join-table rows for this cc_pair that are NOT in the live set.
If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
(i.e. the connector no longer has any hierarchy nodes). Callers that want a
no-op when there are no live nodes must guard before calling.
Returns the number of deleted rows.
"""
stmt = delete(HierarchyNodeByConnectorCredentialPair).where(
HierarchyNodeByConnectorCredentialPair.connector_id == connector_id,
HierarchyNodeByConnectorCredentialPair.credential_id == credential_id,
)
if live_hierarchy_node_ids:
stmt = stmt.where(
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_(
live_hierarchy_node_ids
)
)
result: CursorResult = db_session.execute(stmt) # type: ignore[assignment]
deleted = result.rowcount
if commit:
db_session.commit()
elif deleted:
db_session.flush()
return deleted
def delete_orphaned_hierarchy_nodes(
db_session: Session,
source: DocumentSource,
commit: bool = True,
) -> list[str]:
"""Delete hierarchy nodes for a source that have zero cc_pair associations.
SOURCE-type nodes are excluded (they are synthetic roots).
Returns the list of raw_node_ids that were deleted (for cache eviction).
"""
# Find orphaned nodes: no rows in the join table
orphan_stmt = (
select(HierarchyNode.id, HierarchyNode.raw_node_id)
.outerjoin(
HierarchyNodeByConnectorCredentialPair,
HierarchyNode.id
== HierarchyNodeByConnectorCredentialPair.hierarchy_node_id,
)
.where(
HierarchyNode.source == source,
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None),
)
)
orphans = db_session.execute(orphan_stmt).all()
if not orphans:
return []
orphan_ids = [row[0] for row in orphans]
deleted_raw_ids = [row[1] for row in orphans]
db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids)))
if commit:
db_session.commit()
else:
db_session.flush()
return deleted_raw_ids
def reparent_orphaned_hierarchy_nodes(
db_session: Session,
source: DocumentSource,
commit: bool = True,
) -> list[HierarchyNode]:
"""Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node.
After pruning deletes stale nodes, their former children get parent_id=NULL
via the SET NULL cascade. This function points them back to the SOURCE root.
Returns the reparented HierarchyNode objects (with updated parent_id)
so callers can refresh downstream caches.
"""
source_node = get_source_hierarchy_node(db_session, source)
if not source_node:
return []
stmt = select(HierarchyNode).where(
HierarchyNode.source == source,
HierarchyNode.parent_id.is_(None),
HierarchyNode.node_type != HierarchyNodeType.SOURCE,
)
orphans = list(db_session.execute(stmt).scalars().all())
if not orphans:
return []
for node in orphans:
node.parent_id = source_node.id
if commit:
db_session.commit()
else:
db_session.flush()
return orphans

View File

@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
@@ -369,9 +370,9 @@ def upsert_llm_provider(
def sync_model_configurations(
db_session: Session,
provider_name: str,
models: list[dict],
models: list[SyncModelEntry],
) -> int:
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
This inserts NEW models from the source API without overwriting existing ones.
User preferences (is_visible, max_input_tokens) are preserved for existing models.
@@ -379,7 +380,7 @@ def sync_model_configurations(
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
models: List of SyncModelEntry objects describing the fetched models
Returns:
Number of new models added
@@ -393,21 +394,20 @@ def sync_model_configurations(
new_count = 0
for model in models:
model_name = model["name"]
if model_name not in existing_names:
if model.name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
supported_flows = [LLMModelFlowType.CHAT]
if model.get("supports_image_input", False):
if model.supports_image_input:
supported_flows.append(LLMModelFlowType.VISION)
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model_name,
model_name=model.name,
supported_flows=supported_flows,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
display_name=model.get("display_name"),
max_input_tokens=model.max_input_tokens,
display_name=model.display_name,
)
new_count += 1

View File

@@ -25,7 +25,6 @@ from sqlalchemy import desc
from sqlalchemy import Enum
from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
@@ -37,11 +36,9 @@ from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy import event
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import Mapper
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.types import LargeBinary
@@ -120,50 +117,10 @@ class Base(DeclarativeBase):
__abstract__ = True
class _EncryptedBase(TypeDecorator):
"""Base for encrypted column types that wrap values in SensitiveValue."""
class EncryptedString(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
_is_json: bool = False
def wrap_raw(self, value: Any) -> SensitiveValue:
"""Encrypt a raw value and wrap it in SensitiveValue.
Called by the attribute set event so the Python-side type is always
SensitiveValue, regardless of whether the value was loaded from the DB
or assigned in application code.
"""
if self._is_json:
if not isinstance(value, dict):
raise TypeError(
f"EncryptedJson column expected dict, got {type(value).__name__}"
)
raw_str = json.dumps(value)
else:
if not isinstance(value, str):
raise TypeError(
f"EncryptedString column expected str, got {type(value).__name__}"
)
raw_str = value
return SensitiveValue(
encrypted_bytes=encrypt_string_to_bytes(raw_str),
decrypt_fn=decrypt_bytes_to_string,
is_json=self._is_json,
)
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class EncryptedString(_EncryptedBase):
_is_json: bool = False
def process_bind_param(
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
@@ -187,9 +144,20 @@ class EncryptedString(_EncryptedBase):
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class EncryptedJson(_EncryptedBase):
_is_json: bool = True
class EncryptedJson(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(
self,
@@ -197,7 +165,9 @@ class EncryptedJson(_EncryptedBase):
dialect: Dialect, # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw dicts and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
json_str = json.dumps(value)
return encrypt_string_to_bytes(json_str)
@@ -214,40 +184,14 @@ class EncryptedJson(_EncryptedBase):
)
return None
_REGISTERED_ATTRS: set[str] = set()
@event.listens_for(Mapper, "mapper_configured")
def _register_sensitive_value_set_events(
mapper: Mapper,
class_: type,
) -> None:
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
for prop in mapper.column_attrs:
for col in prop.columns:
if isinstance(col.type, _EncryptedBase):
col_type = col.type
attr = getattr(class_, prop.key)
# Guard against double-registration (e.g. if mapper is
# re-configured in test setups)
attr_key = f"{class_.__qualname__}.{prop.key}"
if attr_key in _REGISTERED_ATTRS:
continue
_REGISTERED_ATTRS.add(attr_key)
@event.listens_for(attr, "set", retval=True)
def _wrap_value(
target: Any, # noqa: ARG001
value: Any,
oldvalue: Any, # noqa: ARG001
initiator: Any, # noqa: ARG001
_col_type: _EncryptedBase = col_type,
) -> Any:
if value is not None and not isinstance(value, SensitiveValue):
return _col_type.wrap_raw(value)
return value
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class NullFilteredString(TypeDecorator):
@@ -2426,38 +2370,6 @@ class SyncRecord(Base):
)
class HierarchyNodeByConnectorCredentialPair(Base):
"""Tracks which cc_pairs reference each hierarchy node.
During pruning, stale entries are removed for the current cc_pair.
Hierarchy nodes with zero remaining entries are then deleted.
"""
__tablename__ = "hierarchy_node_by_connector_credential_pair"
hierarchy_node_id: Mapped[int] = mapped_column(
ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
)
connector_id: Mapped[int] = mapped_column(primary_key=True)
credential_id: Mapped[int] = mapped_column(primary_key=True)
__table_args__ = (
ForeignKeyConstraint(
["connector_id", "credential_id"],
[
"connector_credential_pair.connector_id",
"connector_credential_pair.credential_id",
],
ondelete="CASCADE",
),
Index(
"ix_hierarchy_node_cc_pair_connector_credential",
"connector_id",
"credential_id",
),
)
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential pair"""

View File

@@ -205,9 +205,7 @@ def update_persona_access(
NOTE: Callers are responsible for committing."""
needs_sync = False
if is_public is not None:
needs_sync = True
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
@@ -215,7 +213,6 @@ def update_persona_access(
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
# and a non-empty list means "replace with these shares".
if user_ids is not None:
needs_sync = True
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -236,7 +233,6 @@ def update_persona_access(
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
# there shouldn't be any) but raise an error if trying to add actual groups.
if group_ids is not None:
needs_sync = True
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -244,10 +240,6 @@ def update_persona_access(
if group_ids:
raise NotImplementedError("Onyx MIT does not support group-based sharing")
# When sharing changes, user file ACLs need to be updated in the vector DB
if needs_sync:
mark_persona_user_files_for_sync(persona_id, db_session)
def create_update_persona(
persona_id: int | None,
@@ -859,24 +851,6 @@ def update_personas_display_priority(
db_session.commit()
def mark_persona_user_files_for_sync(
persona_id: int,
db_session: Session,
) -> None:
"""When persona sharing changes, mark all of its user files for sync
so that their ACLs get updated in the vector DB."""
persona = (
db_session.query(Persona)
.options(selectinload(Persona.user_files))
.filter(Persona.id == persona_id)
.first()
)
if not persona:
return
file_ids = [uf.id for uf in persona.user_files]
_mark_files_need_persona_sync(db_session, file_ids)
def _mark_files_need_persona_sync(
db_session: Session,
user_file_ids: list[UUID],

View File

@@ -1,161 +0,0 @@
"""Rotate encryption key for all encrypted columns.
Dynamically discovers all columns using EncryptedString / EncryptedJson,
decrypts each value with the old key, and re-encrypts with the current
ENCRYPTION_KEY_SECRET.
The operation is idempotent: rows already encrypted with the current key
are skipped. Commits are made in batches so a crash mid-rotation can be
safely resumed by re-running.
"""
import json
from typing import Any
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.db.models import Base
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
_BATCH_SIZE = 500
def _can_decrypt_with_current_key(data: bytes) -> bool:
"""Check if data is already encrypted with the current key.
Passes the key explicitly so the fallback-to-raw-decode path in
_decrypt_bytes is NOT triggered — a clean success/failure signal.
"""
try:
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
return True
except Exception:
return False
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
"""
results: list[tuple[type, str, list[str], bool]] = []
for mapper in Base.registry.mappers:
model_cls = mapper.class_
pk_names = [col.key for col in mapper.primary_key]
for prop in mapper.column_attrs:
for col in prop.columns:
if isinstance(col.type, EncryptedJson):
results.append((model_cls, prop.key, pk_names, True))
elif isinstance(col.type, EncryptedString):
results.append((model_cls, prop.key, pk_names, False))
return results
def rotate_encryption_key(
db_session: Session,
old_key: str | None,
dry_run: bool = False,
) -> dict[str, int]:
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
Args:
db_session: Active database session.
old_key: The previous encryption key. Pass None or "" if values were
not previously encrypted with a key.
dry_run: If True, count rows that need rotation without modifying data.
Returns:
Dict of "table.column" -> number of rows re-encrypted (or would be).
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
is preserved on crash. Already-rotated rows are detected and skipped,
making the operation safe to re-run.
"""
if not global_version.is_ee_version():
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
if not ENCRYPTION_KEY_SECRET:
raise RuntimeError(
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
"Set the target encryption key in the environment before running."
)
encrypted_columns = _discover_encrypted_columns()
totals: dict[str, int] = {}
for model_cls, col_name, pk_names, is_json in encrypted_columns:
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
col_attr = getattr(model_cls, col_name)
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
# Read raw bytes directly, bypassing the TypeDecorator
raw_col = col_attr.property.columns[0]
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
rows = db_session.execute(stmt).all()
reencrypted = 0
batch_pending = 0
for row in rows:
raw_bytes: bytes | None = row[-1]
if raw_bytes is None:
continue
if _can_decrypt_with_current_key(raw_bytes):
continue
try:
if not old_key:
decrypted_str = raw_bytes.decode("utf-8")
else:
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
# For EncryptedJson, parse back to dict so the TypeDecorator
# can json.dumps() it cleanly (avoids double-encoding).
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
except (ValueError, UnicodeDecodeError) as e:
pk_vals = [row[i] for i in range(len(pk_names))]
logger.warning(
f"Could not decrypt/parse {table_name}.{col_name} "
f"row {pk_vals} — skipping: {e}"
)
continue
if not dry_run:
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
update_stmt = (
update(model_cls).where(*pk_filters).values({col_name: value})
)
db_session.execute(update_stmt)
batch_pending += 1
if batch_pending >= _BATCH_SIZE:
db_session.commit()
batch_pending = 0
reencrypted += 1
# Flush remaining rows in this column
if batch_pending > 0:
db_session.commit()
if reencrypted > 0:
totals[f"{table_name}.{col_name}"] = reencrypted
logger.info(
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
f"{reencrypted} value(s) in {table_name}.{col_name}"
)
return totals

View File

@@ -3,11 +3,9 @@ from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import Persona
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile
@@ -120,31 +118,3 @@ def get_file_ids_by_user_file_ids(
) -> list[str]:
user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
return [user_file.file_id for user_file in user_files]
def fetch_user_files_with_access_relationships(
user_file_ids: list[str],
db_session: Session,
eager_load_groups: bool = False,
) -> list[UserFile]:
"""Fetch user files with the owner and assistant relationships
eagerly loaded (needed for computing access control).
When eager_load_groups is True, Persona.groups is also loaded so that
callers can extract user-group names without a second DB round-trip."""
persona_sub_options = [
selectinload(Persona.users),
selectinload(Persona.user),
]
if eager_load_groups:
persona_sub_options.append(selectinload(Persona.groups))
return (
db_session.query(UserFile)
.options(
joinedload(UserFile.user),
selectinload(UserFile.assistants).options(*persona_sub_options),
)
.filter(UserFile.id.in_(user_file_ids))
.all()
)

View File

@@ -1,4 +1,3 @@
import csv
import gc
import io
import json
@@ -20,7 +19,6 @@ from zipfile import BadZipFile
import chardet
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from PIL import Image
from onyx.configs.constants import ONYX_METADATA_FILENAME
@@ -354,65 +352,6 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
return presentation.markdown
def _worksheet_to_matrix(
worksheet: Worksheet,
) -> list[list[str]]:
"""
Converts a singular worksheet to a matrix of values
"""
rows: list[list[str]] = []
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
row = ["" if cell is None else str(cell) for cell in worksheet_row]
rows.append(row)
return rows
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
"""
Cleans a worksheet matrix by removing rows if there are N consecutive empty
rows and removing cols if there are M consecutive empty columns
"""
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
MAX_EMPTY_COLS = 2
# Row cleanup
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
# Column cleanup (transpose, clean, transpose back)
transposed = list(map(list, zip(*matrix))) if matrix else []
transposed = _remove_empty_runs(transposed, max_empty=MAX_EMPTY_COLS)
matrix = list(map(list, zip(*transposed))) if transposed else []
return matrix
def _remove_empty_runs(
rows: list[list[str]],
max_empty: int,
) -> list[list[str]]:
"""Removes entire runs of empty rows when the run length exceeds max_empty.
Leading and trailing empty rows are always dropped regardless of run length,
since there is no adjacent non-empty row to bound the run.
"""
result: list[list[str]] = []
empty_buffer: list[list[str]] = []
for row in rows:
# Check if empty
if not any(row):
empty_buffer.append(row)
else:
# Add upto max empty rows onto the result - that's what we allow
result.extend(empty_buffer[:max_empty])
# Add the new non-empty row
result.append(row)
empty_buffer = []
return result
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
# TODO: switch back to this approach in a few months when markitdown
# fixes their handling of excel files
@@ -451,15 +390,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise
raise e
text_content = []
for sheet in workbook.worksheets:
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
text_content.append(buf.getvalue().rstrip("\n"))
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)

View File

@@ -19,16 +19,12 @@ class OnyxMimeTypes:
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
"text/x-markdown",
"text/x-log",
"text/x-config",
"text/tab-separated-values",
"application/json",
"application/xml",
"text/xml",
"application/x-yaml",
"application/yaml",
"text/yaml",
"text/x-yaml",
}
DOCUMENT_MIME_TYPES = {
PDF_MIME_TYPE,

View File

@@ -1,5 +1,4 @@
import abc
from typing import cast
from onyx.utils.special_types import JSON_ro
@@ -8,19 +7,6 @@ class KvKeyNotFoundError(Exception):
pass
def unwrap_str(val: JSON_ro) -> str:
"""Unwrap a string stored as {"value": str} in the encrypted KV store.
Also handles legacy plain-string values cached in Redis."""
if isinstance(val, dict):
try:
return cast(str, val["value"])
except KeyError:
raise ValueError(
f"Expected dict with 'value' key, got keys: {list(val.keys())}"
)
return cast(str, val)
class KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable

View File

@@ -43,6 +43,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.AZURE,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
]
@@ -59,6 +60,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
"ollama": "Ollama",
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -109,6 +111,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.LM_STUDIO,
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
}
# Model family name mappings for display name generation

View File

@@ -11,6 +11,8 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
LM_STUDIO_PROVIDER_NAME = "lm_studio"
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
@@ -47,6 +48,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
}
@@ -331,6 +333,7 @@ def get_provider_display_name(provider_name: str) -> str:
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -1,5 +1,9 @@
import re
from enum import Enum
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
SHOW_EVERYONE_ACTION_ID = "show-everyone"

View File

@@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import Tag
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.onyxbot.slack.utils import get_channel_from_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
@@ -41,6 +44,51 @@ srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def resolve_channel_references(
message: str,
client: WebClient,
logger: OnyxLoggingAdapter,
) -> tuple[str, list[Tag]]:
"""Parse Slack channel references from a message, resolve IDs to names,
replace the raw markup with readable #channel-name, and return channel tags
for search filtering."""
tags: list[Tag] = []
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
seen_channel_ids: set[str] = set()
for channel_id, channel_name_from_markup in channel_matches:
if channel_id in seen_channel_ids:
continue
seen_channel_ids.add(channel_id)
channel_name = channel_name_from_markup or None
if not channel_name:
try:
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
channel_name = channel_info.get("name") or None
except Exception:
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
if not channel_name:
continue
# Replace raw Slack markup with readable channel name
if channel_name_from_markup:
message = message.replace(
f"<#{channel_id}|{channel_name_from_markup}>",
f"#{channel_name}",
)
else:
message = message.replace(
f"<#{channel_id}>",
f"#{channel_name}",
)
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
return message, tags
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
@@ -157,6 +205,20 @@ def handle_regular_answer(
user_message = messages[-1]
history_messages = messages[:-1]
# Resolve any <#CHANNEL_ID> references in the user message to readable
# channel names and extract channel tags for search filtering
resolved_message, channel_tags = resolve_channel_references(
message=user_message.message,
client=client,
logger=logger,
)
user_message = ThreadMessage(
message=resolved_message,
sender=user_message.sender,
role=user_message.role,
)
channel_name, _ = get_channel_name_from_id(
client=client,
channel_id=channel,
@@ -207,6 +269,7 @@ def handle_regular_answer(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
tags=channel_tags if channel_tags else None,
)
new_message_request = SendMessageRequest(
@@ -231,6 +294,16 @@ def handle_regular_answer(
slack_context_str=slack_context_str,
)
# If a channel filter was applied but no results were found, override
# the LLM response to avoid hallucinated answers about unindexed channels
if channel_tags and not answer.citation_info and not answer.top_documents:
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
answer.answer = (
f"No indexed data found for {channel_names}. "
"This channel may not be indexed, or there may be no messages "
"matching your query within it."
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
@@ -285,6 +358,7 @@ def handle_regular_answer(
only_respond_if_citations
and not answer.citation_info
and not message_info.bypass_filters
and not channel_tags
):
logger.error(
f"Unable to find citations to answer: '{answer.answer}' - not answering!"

View File

@@ -16,7 +16,6 @@ Cache Strategy:
using only the SOURCE-type node as the ancestor
"""
from typing import cast
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -205,30 +204,6 @@ def cache_hierarchy_nodes_batch(
redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS)
def evict_hierarchy_nodes_from_cache(
redis_client: Redis,
source: DocumentSource,
raw_node_ids: list[str],
) -> None:
"""Remove specific hierarchy nodes from the Redis cache.
Deletes entries from both the parent-chain hash and the raw_id→node_id hash.
"""
if not raw_node_ids:
return
cache_key = _cache_key(source)
raw_id_key = _raw_id_cache_key(source)
# Look up node_ids so we can remove them from the parent-chain hash
raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids))
node_id_strs = [v for v in raw_values if v is not None]
if node_id_strs:
redis_client.hdel(cache_key, *node_id_strs)
redis_client.hdel(raw_id_key, *raw_node_ids)
def get_node_id_from_raw_id(
redis_client: Redis,
source: DocumentSource,

View File

@@ -1905,7 +1905,7 @@ def get_connector_by_id(
@router.post("/connector-request")
def submit_connector_request(
request_data: ConnectorRequestSubmission,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
) -> StatusResponse:
"""
Submit a connector request for Cloud deployments.
@@ -1918,7 +1918,7 @@ def submit_connector_request(
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
# Get user identifier for telemetry
user_email = user.email
user_email = user.email if user else None
distinct_id = user_email or tenant_id
# Track connector request via PostHog telemetry (Cloud only)

View File

@@ -57,6 +57,9 @@ def list_messages(
db_session: Session = Depends(get_session),
) -> MessageListResponse:
"""Get all messages for a build session."""
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
session_manager = SessionManager(db_session)
messages = session_manager.list_messages(session_id, user.id)

View File

@@ -54,14 +54,18 @@ def _require_opensearch(db_session: Session) -> None:
)
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
def _get_user_access_info(
user: User | None, db_session: Session
) -> tuple[str | None, list[str]]:
if not user:
return None, []
return user.email, get_user_external_group_ids(db_session, user)
@router.get(HIERARCHY_NODES_LIST_PATH)
def list_accessible_hierarchy_nodes(
source: DocumentSource,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodesResponse:
_require_opensearch(db_session)
@@ -88,7 +92,7 @@ def list_accessible_hierarchy_nodes(
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
def list_accessible_hierarchy_node_documents(
documents_request: HierarchyNodeDocumentsRequest,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodeDocumentsResponse:
_require_opensearch(db_session)

View File

@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
@router.get("/servers", response_model=MCPServersResponse)
def get_mcp_servers_for_user(
db: Session = Depends(get_session),
user: User = Depends(current_user),
user: User | None = Depends(current_user),
) -> MCPServersResponse:
"""List all MCP servers for use in agent configuration and chat UI.

View File

@@ -10,8 +10,6 @@ from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -37,38 +35,6 @@ def get_safe_filename(upload: UploadFile) -> str:
return upload.filename
def get_upload_size_bytes(upload: UploadFile) -> int | None:
"""Best-effort file size in bytes without consuming the stream."""
if upload.size is not None:
return upload.size
try:
current_pos = upload.file.tell()
upload.file.seek(0, 2)
size = upload.file.tell()
upload.file.seek(current_pos)
return size
except Exception as e:
logger.warning(
"Could not determine upload size via stream seek "
f"(filename='{get_safe_filename(upload)}', "
f"error_type={type(e).__name__}, error={e})"
)
return None
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
"""Return True when upload size is known and exceeds max_bytes."""
size_bytes = get_upload_size_bytes(upload)
if size_bytes is None:
logger.warning(
"Could not determine upload size; skipping size-limit check for "
f"'{get_safe_filename(upload)}'"
)
return False
return size_bytes > max_bytes
# Guard against extremely large images
Image.MAX_IMAGE_PIXELS = 12000 * 12000
@@ -193,18 +159,6 @@ def categorize_uploaded_files(
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
)
)
continue
extension = get_file_ext(filename)
# If image, estimate tokens via dedicated method first

View File

@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.server.manage.llm.utils import generate_bedrock_display_name
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _sync_fetched_models(
db_session: Session,
provider_name: str,
models: list[SyncModelEntry],
source_label: str,
) -> None:
"""Sync fetched models to DB for the given provider.
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of SyncModelEntry objects describing the fetched models
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
"""
try:
new_count = sync_model_configurations(
db_session=db_session,
provider_name=provider_name,
models=models,
)
if new_count > 0:
logger.info(
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
# Keys in custom_config that contain sensitive credentials
_SENSITIVE_CONFIG_KEYS = {
"vertex_credentials",
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
for r in results
],
source_label="Bedrock",
)
return results
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
for r in sorted_results
],
source_label="Ollama",
)
return sorted_results
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
for r in sorted_results
],
source_label="OpenRouter",
)
return sorted_results
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
for r in sorted_results
],
source_label="LM Studio",
)
return sorted_results
@admin_router.post("/litellm/available-models")
def get_litellm_available_models(
request: LitellmModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
"""Fetch available models from Litellm proxy /v1/models endpoint."""
response_json = _get_litellm_models_response(
api_key=request.api_key, api_base=request.api_base
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Litellm endpoint",
)
results: list[LitellmFinalModelResponse] = []
for model in models:
try:
model_details = LitellmModelDetails.model_validate(model)
results.append(
LitellmFinalModelResponse(
provider_name=model_details.owned_by,
model_name=model_details.id,
)
)
except Exception as e:
logger.warning(
"Failed to parse Litellm model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Litellm",
)
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.model_name,
display_name=r.model_name,
)
for r in sorted_results
],
source_label="LiteLLM",
)
return sorted_results
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"LiteLLM models endpoint not found at {url}. "
"Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)

View File

@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
default_text=default_text,
default_vision=default_vision,
)
class SyncModelEntry(BaseModel):
"""Typed model for syncing fetched models to the DB."""
name: str
display_name: str
max_input_tokens: int | None = None
supports_image_input: bool = False
class LitellmModelsRequest(BaseModel):
api_key: str
api_base: str
provider_name: str | None = None # Optional: to save models to existing provider
class LitellmModelDetails(BaseModel):
"""Response model for Litellm proxy /api/v1/models endpoint"""
id: str # Model ID (e.g. "gpt-4o")
object: str # "model"
created: int # Unix timestamp in seconds
owned_by: str # Provider name (e.g. "openai")
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")

View File

@@ -42,7 +42,6 @@ class StreamingType(Enum):
REASONING_DONE = "reasoning_done"
CITATION_INFO = "citation_info"
TOOL_CALL_DEBUG = "tool_call_debug"
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
MEMORY_TOOL_START = "memory_tool_start"
MEMORY_TOOL_DELTA = "memory_tool_delta"
@@ -277,15 +276,6 @@ class CustomToolDelta(BaseObj):
error: CustomToolErrorInfo | None = None
class ToolCallArgumentDelta(BaseObj):
type: Literal["tool_call_argument_delta"] = (
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
)
tool_type: str
argument_deltas: dict[str, Any]
################################################
# File Reader Packets
################################################
@@ -407,7 +397,6 @@ PacketObj = Union[
# Citation Packets
CitationInfo,
ToolCallDebug,
ToolCallArgumentDelta,
# Deep Research Packets
DeepResearchPlanStart,
DeepResearchPlanDelta,

View File

@@ -78,7 +78,6 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = None
# Connector settings
show_extra_connectors: bool | None = True

View File

@@ -3,7 +3,6 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
@@ -51,7 +50,6 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
return settings

View File

@@ -56,23 +56,3 @@ def get_built_in_tool_ids() -> list[str]:
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
return BUILT_IN_TOOL_MAP[in_code_tool_id]
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
"""Build a mapping from LLM-facing tool name to tool class."""
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
for cls in BUILT_IN_TOOL_MAP.values():
name_attr = cls.__dict__.get("name")
if isinstance(name_attr, property) and name_attr.fget is not None:
tool_name = name_attr.fget(cls)
elif isinstance(name_attr, str):
tool_name = name_attr
else:
raise ValueError(
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
)
result[tool_name] = cls
return result
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()

View File

@@ -92,7 +92,3 @@ class Tool(abc.ABC, Generic[TOverride]):
**llm_kwargs: Any,
) -> ToolResponse:
raise NotImplementedError
@classmethod
def should_emit_argument_deltas(cls) -> bool:
return False

View File

@@ -376,8 +376,3 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
rich_response=None,
llm_facing_response=llm_response,
)
@classmethod
@override
def should_emit_argument_deltas(cls) -> bool:
return True

View File

@@ -11,20 +11,16 @@ logger = setup_logger()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
def _encrypt_string(input_str: str) -> bytes:
if ENCRYPTION_KEY_SECRET:
logger.warning("MIT version of Onyx does not support encryption of secrets.")
elif key is not None:
logger.debug("MIT encrypt called with explicit key — key ignored.")
return input_str.encode()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
if ENCRYPTION_KEY_SECRET:
logger.warning("MIT version of Onyx does not support decryption of secrets.")
elif key is not None:
logger.debug("MIT decrypt called with explicit key — key ignored.")
def _decrypt_bytes(input_bytes: bytes) -> str:
# No need to double warn. If you wish to learn more about encryption features
# refer to the Onyx EE code
return input_bytes.decode()
@@ -90,15 +86,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
return masked
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
def encrypt_string_to_bytes(intput_str: str) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(intput_str, key=key)
return versioned_encryption_fn(intput_str)
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(intput_bytes, key=key)
return versioned_decryption_fn(intput_bytes)

View File

@@ -128,8 +128,6 @@ class SensitiveValue(Generic[T]):
value = self._decrypt()
if not apply_mask:
# Callers must not mutate the returned dict — doing so would
# desync the cache from the encrypted bytes and the DB.
return value
# Apply masking
@@ -176,20 +174,18 @@ class SensitiveValue(Generic[T]):
)
def __eq__(self, other: Any) -> bool:
"""Compare SensitiveValues by their decrypted content."""
# NOTE: if you attempt to compare a string/dict to a SensitiveValue,
# this comparison will return NotImplemented, which then evaluates to False.
# This is the convention and required for SQLAlchemy's attribute tracking.
if not isinstance(other, SensitiveValue):
return NotImplemented
return self._decrypt() == other._decrypt()
"""Prevent direct comparison which might expose value."""
if isinstance(other, SensitiveValue):
# Compare encrypted bytes for equality check
return self._encrypted_bytes == other._encrypted_bytes
raise SensitiveAccessError(
"Cannot compare SensitiveValue with non-SensitiveValue. "
"Use .get_value(apply_mask=True/False) to access the value for comparison."
)
def __hash__(self) -> int:
"""Hash based on decrypted content."""
value = self._decrypt()
if isinstance(value, dict):
return hash(json.dumps(value, sort_keys=True))
return hash(value)
"""Allow hashing based on encrypted bytes."""
return hash(self._encrypted_bytes)
# Prevent JSON serialization
def __json__(self) -> Any:

View File

@@ -2,6 +2,7 @@ import contextvars
import threading
import uuid
from enum import Enum
from typing import cast
import requests
@@ -14,7 +15,6 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.key_value_store.interface import unwrap_str
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -25,7 +25,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
_CACHED_INSTANCE_DOMAIN: str | None = None
@@ -63,10 +62,10 @@ def get_or_generate_uuid() -> str:
kv_store = get_kv_store()
try:
_CACHED_UUID = unwrap_str(kv_store.load(KV_CUSTOMER_UUID_KEY))
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
except KvKeyNotFoundError:
_CACHED_UUID = str(uuid.uuid4())
kv_store.store(KV_CUSTOMER_UUID_KEY, {"value": _CACHED_UUID}, encrypt=True)
kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True)
return _CACHED_UUID
@@ -80,16 +79,14 @@ def _get_or_generate_instance_domain() -> str | None: #
kv_store = get_kv_store()
try:
_CACHED_INSTANCE_DOMAIN = unwrap_str(kv_store.load(KV_INSTANCE_DOMAIN_KEY))
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_current_tenant() as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
kv_store.store(
KV_INSTANCE_DOMAIN_KEY,
{"value": _CACHED_INSTANCE_DOMAIN},
encrypt=True,
KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True
)
return _CACHED_INSTANCE_DOMAIN

View File

@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.3
onyx-devtools==0.7.0
# via onyx
openai==2.14.0
# via
@@ -406,7 +406,7 @@ referencing==0.36.2
# jsonschema-specifications
regex==2025.11.3
# via tiktoken
release-tag==0.4.3
release-tag==0.5.2
# via onyx
reorder-python-imports-black==3.14.0
# via onyx

View File

@@ -1,93 +1,48 @@
"""Decrypt a raw hex-encoded credential value.
Usage:
python -m scripts.decrypt <hex_value>
python -m scripts.decrypt <hex_value> --key "my-encryption-key"
python -m scripts.decrypt <hex_value> --key ""
Pass --key "" to skip decryption and just decode the raw bytes as UTF-8.
Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment.
"""
import argparse
import binascii
import json
import os
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402
from onyx.utils.variable_functionality import global_version # noqa: E402
from onyx.utils.encryption import decrypt_bytes_to_string
def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None:
"""Decrypt and display a raw encrypted credential value.
def decrypt_raw_credential(encrypted_value: str) -> None:
"""Decrypt and display a raw encrypted credential value
Args:
encrypted_value: The hex-encoded encrypted credential value.
key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET,
empty string means just decode as UTF-8.
encrypted_value: The hex encoded encrypted credential value
"""
# Strip common hex prefixes
if encrypted_value.startswith("\\x"):
encrypted_value = encrypted_value[2:]
elif encrypted_value.startswith("x"):
encrypted_value = encrypted_value[1:]
print(encrypted_value)
try:
raw_bytes = binascii.unhexlify(encrypted_value)
# If string starts with 'x', remove it as it's just a prefix indicating hex
if encrypted_value.startswith("x"):
encrypted_value = encrypted_value[1:]
elif encrypted_value.startswith("\\x"):
encrypted_value = encrypted_value[2:]
# Convert hex string to bytes
encrypted_bytes = binascii.unhexlify(encrypted_value)
# Decrypt the bytes
decrypted_str = decrypt_bytes_to_string(encrypted_bytes)
# Parse and pretty print the decrypted JSON
decrypted_json = json.loads(decrypted_str)
print("Decrypted credential value:")
print(json.dumps(decrypted_json, indent=2))
except binascii.Error:
print("Error: Invalid hex-encoded string")
sys.exit(1)
print("Error: Invalid hex encoded string")
if key == "":
# Empty key → just decode as UTF-8, no decryption
try:
decrypted_str = raw_bytes.decode("utf-8")
except UnicodeDecodeError as e:
print(f"Error decoding bytes as UTF-8: {e}")
sys.exit(1)
else:
print(key)
try:
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key)
except Exception as e:
print(f"Error decrypting value: {e}")
sys.exit(1)
except json.JSONDecodeError as e:
print(f"Decrypted raw value (not JSON): {e}")
# Try to pretty-print as JSON, otherwise print raw
try:
parsed = json.loads(decrypted_str)
print(json.dumps(parsed, indent=2))
except json.JSONDecodeError:
print(decrypted_str)
def main() -> None:
parser = argparse.ArgumentParser(
description="Decrypt a hex-encoded credential value."
)
parser.add_argument(
"value",
help="Hex-encoded encrypted value to decrypt.",
)
parser.add_argument(
"--key",
default=None,
help=(
"Encryption key. Omit to use ENCRYPTION_KEY_SECRET from env. "
'Pass "" (empty) to just decode as UTF-8 without decryption.'
),
)
args = parser.parse_args()
global_version.set_ee()
decrypt_raw_credential(args.value, key=args.key)
global_version.unset_ee()
except Exception as e:
print(f"Error decrypting value: {e}")
if __name__ == "__main__":
main()
if len(sys.argv) != 2:
print("Usage: python decrypt.py <hex_encoded_encrypted_value>")
sys.exit(1)
encrypted_value = sys.argv[1]
decrypt_raw_credential(encrypted_value)

View File

@@ -1,107 +0,0 @@
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
Decrypts all encrypted columns using the old key (or raw decode if the old key
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
Usage (docker):
docker exec -it onyx-api_server-1 \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Usage (kubernetes):
kubectl exec -it <pod> -- \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Omit --old-key (or pass "") if secrets were not previously encrypted.
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
or --all-tenants to iterate every tenant.
"""
import argparse
import os
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
from onyx.utils.variable_functionality import global_version # noqa: E402
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
print(f"Re-encrypting secrets for tenant: {tenant_id}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
if results:
for col, count in results.items():
print(
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
)
else:
print("No rows needed re-encryption.")
def main() -> None:
parser = argparse.ArgumentParser(
description="Re-encrypt secrets under the current encryption key."
)
parser.add_argument(
"--old-key",
default=None,
help="Previous encryption key. Omit or pass empty string if not applicable.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be re-encrypted without making changes.",
)
tenant_group = parser.add_mutually_exclusive_group()
tenant_group.add_argument(
"--tenant-id",
default=None,
help="Target a specific tenant schema.",
)
tenant_group.add_argument(
"--all-tenants",
action="store_true",
help="Iterate all tenants.",
)
args = parser.parse_args()
old_key = args.old_key if args.old_key else None
global_version.set_ee()
SqlEngine.init_engine(pool_size=5, max_overflow=2)
if args.dry_run:
print("DRY RUN — no changes will be made")
if args.all_tenants:
tenant_ids = get_all_tenant_ids()
print(f"Found {len(tenant_ids)} tenant(s)")
failed_tenants: list[str] = []
for tid in tenant_ids:
try:
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
except Exception as e:
print(f" ERROR for tenant {tid}: {e}")
failed_tenants.append(tid)
if failed_tenants:
print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
sys.exit(1)
else:
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -48,7 +48,7 @@ def test_gitlab_connector_basic(gitlab_connector: GitlabConnector) -> None:
# --- Specific Document Details to Validate ---
target_mr_id = f"https://{gitlab_base_url}/{project_path}/-/merge_requests/1"
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/work_items/2"
target_issue_id = f"https://{gitlab_base_url}/{project_path}/-/issues/2"
target_code_file_semantic_id = "README.md"
# ---

View File

@@ -7,8 +7,6 @@ Verifies that:
3. Upserting is idempotent (running twice doesn't duplicate nodes)
4. Document-to-hierarchy-node linkage is updated during pruning
5. link_hierarchy_nodes_to_documents links nodes that are also documents
6. HierarchyNodeByConnectorCredentialPair join table population and pruning
7. Orphaned hierarchy node deletion and re-parenting
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
combined with a real PostgreSQL database for verifying persistence.
@@ -26,27 +24,16 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.connectors.models import InputType
from onyx.connectors.models import SlimDocument
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import HierarchyNodeType
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
from onyx.db.hierarchy import ensure_source_node_exists
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.kg.models import KGStage
@@ -155,80 +142,13 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
# ---------------------------------------------------------------------------
def _create_cc_pair(
db_session: Session,
source: DocumentSource = TEST_SOURCE,
) -> ConnectorCredentialPair:
"""Create a real Connector + Credential + ConnectorCredentialPair for testing."""
connector = Connector(
name=f"Test {source.value} Connector",
source=source,
input_type=InputType.LOAD_STATE,
connector_specific_config={},
)
db_session.add(connector)
db_session.flush()
credential = Credential(
source=source,
credential_json={},
admin_public=True,
)
db_session.add(credential)
db_session.flush()
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,
credential_id=credential.id,
name=f"Test {source.value} CC Pair",
status=ConnectorCredentialPairStatus.ACTIVE,
access_type=AccessType.PUBLIC,
)
db_session.add(cc_pair)
db_session.commit()
db_session.refresh(cc_pair)
return cc_pair
def _cleanup_test_data(db_session: Session) -> None:
"""Remove all test hierarchy nodes and documents to isolate tests."""
for doc_id in SLIM_DOC_IDS:
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
test_connector_ids_q = db_session.query(Connector.id).filter(
Connector.source == TEST_SOURCE,
Connector.name.like("Test %"),
)
db_session.query(HierarchyNodeByConnectorCredentialPair).filter(
HierarchyNodeByConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
).delete(synchronize_session="fetch")
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == TEST_SOURCE
).delete()
db_session.flush()
# Collect credential IDs before deleting cc_pairs (bulk query.delete()
# bypasses ORM-level cascade, so credentials won't be auto-removed).
credential_ids = [
row[0]
for row in db_session.query(ConnectorCredentialPair.credential_id)
.filter(ConnectorCredentialPair.connector_id.in_(test_connector_ids_q))
.all()
]
db_session.query(ConnectorCredentialPair).filter(
ConnectorCredentialPair.connector_id.in_(test_connector_ids_q)
).delete(synchronize_session="fetch")
db_session.query(Connector).filter(
Connector.source == TEST_SOURCE,
Connector.name.like("Test %"),
).delete(synchronize_session="fetch")
if credential_ids:
db_session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
synchronize_session="fetch"
)
db_session.commit()
@@ -259,8 +179,15 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
result = extract_ids_from_runnable_connector(connector, callback=None)
# raw_id_to_parent should contain ONLY document IDs, not hierarchy node IDs
assert result.raw_id_to_parent.keys() == set(SLIM_DOC_IDS)
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
expected_ids = {
CHANNEL_A_ID,
CHANNEL_B_ID,
CHANNEL_C_ID,
*SLIM_DOC_IDS,
}
assert result.raw_id_to_parent.keys() == expected_ids
# Hierarchy nodes should be the 3 channels
assert len(result.hierarchy_nodes) == 3
@@ -468,9 +395,9 @@ def test_extraction_preserves_parent_hierarchy_raw_node_id(
result.raw_id_to_parent[doc_id] == expected_parent
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
# Hierarchy node IDs should NOT be in raw_id_to_parent
# Hierarchy node entries have None parent (they aren't documents)
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
assert channel_id not in result.raw_id_to_parent
assert result.raw_id_to_parent[channel_id] is None
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
@@ -638,241 +565,3 @@ def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
commit=False,
)
assert linked == 0
# ---------------------------------------------------------------------------
# Join table + pruning tests
# ---------------------------------------------------------------------------
def test_upsert_hierarchy_node_cc_pair_entries(db_session: Session) -> None:
"""upsert_hierarchy_node_cc_pair_entries should insert rows and be idempotent."""
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
cc_pair = _create_cc_pair(db_session)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_ids = [n.id for n in upserted]
# First call — should insert rows
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=node_ids,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
commit=True,
)
rows = (
db_session.query(HierarchyNodeByConnectorCredentialPair)
.filter(
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
HierarchyNodeByConnectorCredentialPair.credential_id
== cc_pair.credential_id,
)
.all()
)
assert len(rows) == 3
# Second call — idempotent, same count
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=node_ids,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
commit=True,
)
rows_after = (
db_session.query(HierarchyNodeByConnectorCredentialPair)
.filter(
HierarchyNodeByConnectorCredentialPair.connector_id == cc_pair.connector_id,
HierarchyNodeByConnectorCredentialPair.credential_id
== cc_pair.credential_id,
)
.all()
)
assert len(rows_after) == 3
def test_remove_stale_entries_and_delete_orphans(db_session: Session) -> None:
"""After removing stale join-table entries, orphaned hierarchy nodes should
be deleted and the SOURCE node should survive."""
_cleanup_test_data(db_session)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
cc_pair = _create_cc_pair(db_session)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
all_ids = [n.id for n in upserted]
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=all_ids,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
commit=True,
)
# Now simulate a pruning run where only channel A survived
channel_a = get_hierarchy_node_by_raw_id(db_session, CHANNEL_A_ID, TEST_SOURCE)
assert channel_a is not None
live_ids = {channel_a.id}
stale_removed = remove_stale_hierarchy_node_cc_pair_entries(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
live_hierarchy_node_ids=live_ids,
commit=True,
)
assert stale_removed == 2
# Delete orphaned nodes
deleted_raw_ids = delete_orphaned_hierarchy_nodes(
db_session=db_session,
source=TEST_SOURCE,
commit=True,
)
assert set(deleted_raw_ids) == {CHANNEL_B_ID, CHANNEL_C_ID}
# Verify only channel A + SOURCE remain
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
remaining_raw = {n.raw_node_id for n in remaining}
assert remaining_raw == {CHANNEL_A_ID, source_node.raw_node_id}
def test_multi_cc_pair_prevents_premature_deletion(db_session: Session) -> None:
"""A hierarchy node shared by two cc_pairs should NOT be deleted when only
one cc_pair removes its association."""
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
cc_pair_1 = _create_cc_pair(db_session)
cc_pair_2 = _create_cc_pair(db_session)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
all_ids = [n.id for n in upserted]
# cc_pair 1 owns all 3
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=all_ids,
connector_id=cc_pair_1.connector_id,
credential_id=cc_pair_1.credential_id,
commit=True,
)
# cc_pair 2 also owns all 3
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=all_ids,
connector_id=cc_pair_2.connector_id,
credential_id=cc_pair_2.credential_id,
commit=True,
)
# cc_pair 1 prunes — keeps none
remove_stale_hierarchy_node_cc_pair_entries(
db_session=db_session,
connector_id=cc_pair_1.connector_id,
credential_id=cc_pair_1.credential_id,
live_hierarchy_node_ids=set(),
commit=True,
)
# Orphan deletion should find nothing because cc_pair 2 still references them
deleted = delete_orphaned_hierarchy_nodes(
db_session=db_session,
source=TEST_SOURCE,
commit=True,
)
assert deleted == []
# All 3 nodes + SOURCE should still exist
remaining = get_all_hierarchy_nodes_for_source(db_session, TEST_SOURCE)
assert len(remaining) == 4
def test_reparent_orphaned_children(db_session: Session) -> None:
"""After deleting a parent hierarchy node, its children should be
re-parented to the SOURCE node."""
_cleanup_test_data(db_session)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
cc_pair = _create_cc_pair(db_session)
# Create a parent node and a child node
parent_node = PydanticHierarchyNode(
raw_node_id="PARENT",
raw_parent_id=None,
display_name="Parent",
node_type=HierarchyNodeType.CHANNEL,
)
child_node = PydanticHierarchyNode(
raw_node_id="CHILD",
raw_parent_id="PARENT",
display_name="Child",
node_type=HierarchyNodeType.CHANNEL,
)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=[parent_node, child_node],
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
assert len(upserted) == 2
parent_db = get_hierarchy_node_by_raw_id(db_session, "PARENT", TEST_SOURCE)
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
assert parent_db is not None and child_db is not None
assert child_db.parent_id == parent_db.id
# Associate only the child with a cc_pair (parent is orphaned)
upsert_hierarchy_node_cc_pair_entries(
db_session=db_session,
hierarchy_node_ids=[child_db.id],
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
commit=True,
)
# Delete orphaned nodes (parent has no cc_pair entry)
deleted = delete_orphaned_hierarchy_nodes(
db_session=db_session,
source=TEST_SOURCE,
commit=True,
)
assert "PARENT" in deleted
# Child should now have parent_id=NULL (SET NULL cascade)
db_session.expire_all()
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
assert child_db is not None
assert child_db.parent_id is None
# Re-parent orphans to SOURCE
reparented = reparent_orphaned_hierarchy_nodes(
db_session=db_session,
source=TEST_SOURCE,
commit=True,
)
assert len(reparented) == 1
db_session.expire_all()
child_db = get_hierarchy_node_by_raw_id(db_session, "CHILD", TEST_SOURCE)
assert child_db is not None
assert child_db.parent_id == source_node.id

View File

@@ -1,90 +0,0 @@
"""Test that Credential with nested JSON round-trips through SensitiveValue correctly.
Exercises the full encrypt → store → read → decrypt → SensitiveValue path
with realistic nested OAuth credential data, and verifies SQLAlchemy dirty
tracking works with nested dict comparison.
Requires a running Postgres instance.
"""
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.models import Credential
from onyx.utils.sensitive import SensitiveValue
# NOTE: this is not the real shape of a Drive credential,
# but it is intended to test nested JSON credential handling
_NESTED_CRED_JSON = {
"oauth_tokens": {
"access_token": "ya29.abc123",
"refresh_token": "1//xEg-def456",
},
"scopes": ["read", "write", "admin"],
"client_config": {
"client_id": "123.apps.googleusercontent.com",
"client_secret": "GOCSPX-secret",
},
}
def test_nested_credential_json_round_trip(db_session: Session) -> None:
"""Nested OAuth credential survives encrypt → store → read → decrypt."""
credential = Credential(
source=DocumentSource.GOOGLE_DRIVE,
credential_json=_NESTED_CRED_JSON,
)
db_session.add(credential)
db_session.flush()
# Immediate read (no DB round-trip) — tests the set event wrapping
assert isinstance(credential.credential_json, SensitiveValue)
assert credential.credential_json.get_value(apply_mask=False) == _NESTED_CRED_JSON
# DB round-trip — tests process_result_value
db_session.expire(credential)
reloaded = credential.credential_json
assert isinstance(reloaded, SensitiveValue)
assert reloaded.get_value(apply_mask=False) == _NESTED_CRED_JSON
db_session.rollback()
def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None:
"""Re-assigning the same nested dict should not mark the session dirty."""
credential = Credential(
source=DocumentSource.GOOGLE_DRIVE,
credential_json=_NESTED_CRED_JSON,
)
db_session.add(credential)
db_session.flush()
# Clear dirty state from the insert
db_session.expire(credential)
_ = credential.credential_json # force reload
# Re-assign identical value
credential.credential_json = _NESTED_CRED_JSON # type: ignore[assignment]
assert not db_session.is_modified(credential)
db_session.rollback()
def test_assign_different_nested_json_is_dirty(db_session: Session) -> None:
"""Assigning a different nested dict should mark the session dirty."""
credential = Credential(
source=DocumentSource.GOOGLE_DRIVE,
credential_json=_NESTED_CRED_JSON,
)
db_session.add(credential)
db_session.flush()
db_session.expire(credential)
_ = credential.credential_json # force reload
modified_cred = {**_NESTED_CRED_JSON, "scopes": ["read"]}
credential.credential_json = modified_cred # type: ignore[assignment]
assert db_session.is_modified(credential)
db_session.rollback()

View File

@@ -1,305 +0,0 @@
"""Tests for rotate_encryption_key against real Postgres.
Uses real ORM models (Credential, InternetSearchProvider) and the actual
Postgres database. Discovery is mocked in rotation tests to scope mutations
to only the test rows — the real _discover_encrypted_columns walk is tested
separately in TestDiscoverEncryptedColumns.
Requires a running Postgres instance. Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
"""
import json
from collections.abc import Generator
from unittest.mock import patch
import pytest
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from onyx.configs.constants import DocumentSource
from onyx.db.models import Credential
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.db.models import InternetSearchProvider
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
from onyx.db.rotate_encryption_key import rotate_encryption_key
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
EE_MODULE = "ee.onyx.utils.encryption"
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
OLD_KEY = "o" * 16
NEW_KEY = "n" * 16
@pytest.fixture(autouse=True)
def _enable_ee() -> Generator[None, None, None]:
prev = global_version._is_ee
global_version.set_ee()
fetch_versioned_implementation.cache_clear()
yield
global_version._is_ee = prev
fetch_versioned_implementation.cache_clear()
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
_get_trimmed_key.cache_clear()
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
col = Credential.__table__.c.credential_json
stmt = select(col.cast(LargeBinary)).where(
Credential.__table__.c.id == credential_id
)
return db_session.execute(stmt).scalar()
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
"""Read raw bytes from InternetSearchProvider.api_key."""
col = InternetSearchProvider.__table__.c.api_key
stmt = select(col.cast(LargeBinary)).where(
InternetSearchProvider.__table__.c.id == isp_id
)
return db_session.execute(stmt).scalar()
class TestDiscoverEncryptedColumns:
"""Verify _discover_encrypted_columns finds real production models."""
def test_discovers_credential_json(self) -> None:
results = _discover_encrypted_columns()
found = {
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
for model_cls, col_name, _, is_json in results
}
assert ("credential", "credential_json", True) in found
def test_discovers_internet_search_provider_api_key(self) -> None:
results = _discover_encrypted_columns()
found = {
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
for model_cls, col_name, _, is_json in results
}
assert ("internet_search_provider", "api_key", False) in found
def test_all_encrypted_string_columns_are_not_json(self) -> None:
results = _discover_encrypted_columns()
for model_cls, col_name, _, is_json in results:
col = getattr(model_cls, col_name).property.columns[0]
if isinstance(col.type, EncryptedString):
assert not is_json, (
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
f"but is_json={is_json}"
)
def test_all_encrypted_json_columns_are_json(self) -> None:
results = _discover_encrypted_columns()
for model_cls, col_name, _, is_json in results:
col = getattr(model_cls, col_name).property.columns[0]
if isinstance(col.type, EncryptedJson):
assert is_json, (
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
f"but is_json={is_json}"
)
class TestRotateCredential:
"""Test rotation against the real Credential table (EncryptedJson).
Discovery is scoped to only the Credential model to avoid mutating
other tables in the test database.
"""
@pytest.fixture(autouse=True)
def _limit_discovery(self) -> Generator[None, None, None]:
with patch(
f"{ROTATE_MODULE}._discover_encrypted_columns",
return_value=[(Credential, "credential_json", ["id"], True)],
):
yield
@pytest.fixture()
def credential_id(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> Generator[int, None, None]:
"""Insert a Credential row with raw encrypted bytes, clean up after."""
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
result = db_session.execute(
text(
"INSERT INTO credential "
"(source, credential_json, admin_public, curator_public) "
"VALUES (:source, :cred_json, true, false) "
"RETURNING id"
),
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
)
cred_id = result.scalar_one()
db_session.commit()
yield cred_id
db_session.execute(
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
)
db_session.commit()
def test_rotates_credential_json(
self, db_session: Session, credential_id: int
) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
assert totals.get("credential.credential_json", 0) >= 1
raw = _raw_credential_bytes(db_session, credential_id)
assert raw is not None
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
assert decrypted["api_key"] == "sk-test-1234"
assert decrypted["endpoint"] == "https://example.com"
def test_skips_already_rotated(
self, db_session: Session, credential_id: int
) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
rotate_encryption_key(db_session, old_key=OLD_KEY)
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
raw = _raw_credential_bytes(db_session, credential_id)
assert raw is not None
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
assert decrypted["api_key"] == "sk-test-1234"
def test_dry_run_does_not_modify(
self, db_session: Session, credential_id: int
) -> None:
original = _raw_credential_bytes(db_session, credential_id)
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
assert totals.get("credential.credential_json", 0) >= 1
raw_after = _raw_credential_bytes(db_session, credential_id)
assert raw_after == original
class TestRotateInternetSearchProvider:
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
Discovery is scoped to only the InternetSearchProvider model to avoid
mutating other tables in the test database.
"""
@pytest.fixture(autouse=True)
def _limit_discovery(self) -> Generator[None, None, None]:
with patch(
f"{ROTATE_MODULE}._discover_encrypted_columns",
return_value=[
(InternetSearchProvider, "api_key", ["id"], False),
],
):
yield
@pytest.fixture()
def isp_id(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> Generator[int, None, None]:
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
result = db_session.execute(
text(
"INSERT INTO internet_search_provider "
"(name, provider_type, api_key, is_active) "
"VALUES (:name, :ptype, :api_key, false) "
"RETURNING id"
),
{
"name": f"test-rotation-{id(self)}",
"ptype": "test",
"api_key": encrypted,
},
)
isp_id = result.scalar_one()
db_session.commit()
yield isp_id
db_session.execute(
text("DELETE FROM internet_search_provider WHERE id = :id"),
{"id": isp_id},
)
db_session.commit()
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
assert totals.get("internet_search_provider.api_key", 0) >= 1
raw = _raw_isp_bytes(db_session, isp_id)
assert raw is not None
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
def test_rotates_from_unencrypted(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> None:
"""Test rotating data that was stored without any encryption key."""
result = db_session.execute(
text(
"INSERT INTO internet_search_provider "
"(name, provider_type, api_key, is_active) "
"VALUES (:name, :ptype, :api_key, false) "
"RETURNING id"
),
{
"name": f"test-raw-{id(self)}",
"ptype": "test",
"api_key": b"raw-api-key",
},
)
isp_id = result.scalar_one()
db_session.commit()
try:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=None)
assert totals.get("internet_search_provider.api_key", 0) >= 1
raw = _raw_isp_bytes(db_session, isp_id)
assert raw is not None
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
finally:
db_session.execute(
text("DELETE FROM internet_search_provider WHERE id = :id"),
{"id": isp_id},
)
db_session.commit()

View File

@@ -85,7 +85,7 @@ def test_group_overlap_filter(
results = _get_accessible_hierarchy_nodes_for_source(
db_session,
source=DocumentSource.GOOGLE_DRIVE,
user_email="",
user_email=None,
external_group_ids=["group_engineering"],
)
result_ids = {n.raw_node_id for n in results}
@@ -124,7 +124,7 @@ def test_no_credentials_returns_only_public(
results = _get_accessible_hierarchy_nodes_for_source(
db_session,
source=DocumentSource.GOOGLE_DRIVE,
user_email="",
user_email=None,
external_group_ids=[],
)
result_ids = {n.raw_node_id for n in results}

View File

@@ -1,85 +0,0 @@
"""Tests that SlackBot CRUD operations return properly typed SensitiveValue fields.
Regression test for the bug where insert_slack_bot/update_slack_bot returned
objects with raw string tokens instead of SensitiveValue wrappers, causing
'str object has no attribute get_value' errors in SlackBot.from_model().
"""
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.db.slack_bot import insert_slack_bot
from onyx.db.slack_bot import update_slack_bot
from onyx.server.manage.models import SlackBot
from onyx.utils.sensitive import SensitiveValue
def _unique(prefix: str) -> str:
return f"{prefix}-{uuid4().hex[:8]}"
def test_insert_slack_bot_returns_sensitive_values(db_session: Session) -> None:
bot_token = _unique("xoxb-insert")
app_token = _unique("xapp-insert")
user_token = _unique("xoxp-insert")
slack_bot = insert_slack_bot(
db_session=db_session,
name=_unique("test-bot-insert"),
enabled=True,
bot_token=bot_token,
app_token=app_token,
user_token=user_token,
)
assert isinstance(slack_bot.bot_token, SensitiveValue)
assert isinstance(slack_bot.app_token, SensitiveValue)
assert isinstance(slack_bot.user_token, SensitiveValue)
assert slack_bot.bot_token.get_value(apply_mask=False) == bot_token
assert slack_bot.app_token.get_value(apply_mask=False) == app_token
assert slack_bot.user_token.get_value(apply_mask=False) == user_token
# Verify from_model works without error
pydantic_bot = SlackBot.from_model(slack_bot)
assert pydantic_bot.bot_token # masked, but not empty
assert pydantic_bot.app_token
def test_update_slack_bot_returns_sensitive_values(db_session: Session) -> None:
slack_bot = insert_slack_bot(
db_session=db_session,
name=_unique("test-bot-update"),
enabled=True,
bot_token=_unique("xoxb-update"),
app_token=_unique("xapp-update"),
)
new_bot_token = _unique("xoxb-update-new")
new_app_token = _unique("xapp-update-new")
new_user_token = _unique("xoxp-update-new")
updated = update_slack_bot(
db_session=db_session,
slack_bot_id=slack_bot.id,
name=_unique("test-bot-updated"),
enabled=False,
bot_token=new_bot_token,
app_token=new_app_token,
user_token=new_user_token,
)
assert isinstance(updated.bot_token, SensitiveValue)
assert isinstance(updated.app_token, SensitiveValue)
assert isinstance(updated.user_token, SensitiveValue)
assert updated.bot_token.get_value(apply_mask=False) == new_bot_token
assert updated.app_token.get_value(apply_mask=False) == new_app_token
assert updated.user_token.get_value(apply_mask=False) == new_user_token
# Verify from_model works without error
pydantic_bot = SlackBot.from_model(updated)
assert pydantic_bot.bot_token
assert pydantic_bot.app_token
assert pydantic_bot.user_token is not None

View File

@@ -148,16 +148,8 @@ class TestOAuthConfigCRUD:
)
# Secrets should be preserved
assert updated_config.client_id is not None
assert original_client_id is not None
assert updated_config.client_id.get_value(
apply_mask=False
) == original_client_id.get_value(apply_mask=False)
assert updated_config.client_secret is not None
assert original_client_secret is not None
assert updated_config.client_secret.get_value(
apply_mask=False
) == original_client_secret.get_value(apply_mask=False)
assert updated_config.client_id == original_client_id
assert updated_config.client_secret == original_client_secret
# But name should be updated
assert updated_config.name == new_name
@@ -181,14 +173,9 @@ class TestOAuthConfigCRUD:
)
# client_id should be cleared (empty string)
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == ""
assert updated_config.client_id == ""
# client_secret should be preserved
assert updated_config.client_secret is not None
assert original_client_secret is not None
assert updated_config.client_secret.get_value(
apply_mask=False
) == original_client_secret.get_value(apply_mask=False)
assert updated_config.client_secret == original_client_secret
def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> None:
"""Test clearing client_secret while preserving client_id"""
@@ -203,14 +190,9 @@ class TestOAuthConfigCRUD:
)
# client_secret should be cleared (empty string)
assert updated_config.client_secret is not None
assert updated_config.client_secret.get_value(apply_mask=False) == ""
assert updated_config.client_secret == ""
# client_id should be preserved
assert updated_config.client_id is not None
assert original_client_id is not None
assert updated_config.client_id.get_value(
apply_mask=False
) == original_client_id.get_value(apply_mask=False)
assert updated_config.client_id == original_client_id
def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> None:
"""Test clearing both client_id and client_secret"""
@@ -225,10 +207,8 @@ class TestOAuthConfigCRUD:
)
# Both should be cleared (empty strings)
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == ""
assert updated_config.client_secret is not None
assert updated_config.client_secret.get_value(apply_mask=False) == ""
assert updated_config.client_id == ""
assert updated_config.client_secret == ""
def test_update_oauth_config_authorization_url(self, db_session: Session) -> None:
"""Test updating authorization_url"""
@@ -295,8 +275,7 @@ class TestOAuthConfigCRUD:
assert updated_config.token_url == new_token_url
assert updated_config.scopes == new_scopes
assert updated_config.additional_params == new_params
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == new_client_id
assert updated_config.client_id == new_client_id
def test_delete_oauth_config(self, db_session: Session) -> None:
"""Test deleting an OAuth configuration"""
@@ -437,8 +416,7 @@ class TestOAuthUserTokenCRUD:
assert user_token.id is not None
assert user_token.oauth_config_id == oauth_config.id
assert user_token.user_id == user.id
assert user_token.token_data is not None
assert user_token.token_data.get_value(apply_mask=False) == token_data
assert user_token.token_data == token_data
assert user_token.created_at is not None
assert user_token.updated_at is not None
@@ -468,13 +446,8 @@ class TestOAuthUserTokenCRUD:
# Should be the same token record (updated, not inserted)
assert updated_token.id == initial_token_id
assert updated_token.token_data is not None
assert (
updated_token.token_data.get_value(apply_mask=False) == updated_token_data
)
assert (
updated_token.token_data.get_value(apply_mask=False) != initial_token_data
)
assert updated_token.token_data == updated_token_data
assert updated_token.token_data != initial_token_data
def test_get_user_oauth_token(self, db_session: Session) -> None:
"""Test retrieving a user's OAuth token"""
@@ -490,8 +463,7 @@ class TestOAuthUserTokenCRUD:
assert retrieved_token is not None
assert retrieved_token.id == created_token.id
assert retrieved_token.token_data is not None
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data
assert retrieved_token.token_data == token_data
def test_get_user_oauth_token_not_found(self, db_session: Session) -> None:
"""Test retrieving a non-existent user token returns None"""
@@ -547,8 +519,7 @@ class TestOAuthUserTokenCRUD:
retrieved_token = get_user_oauth_token(oauth_config.id, user.id, db_session)
assert retrieved_token is not None
assert retrieved_token.id == updated_token.id
assert retrieved_token.token_data is not None
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data2
assert retrieved_token.token_data == token_data2
def test_cascade_delete_user_tokens_on_config_deletion(
self, db_session: Session

View File

@@ -374,14 +374,8 @@ class TestOAuthTokenManagerCodeExchange:
assert call_args[0][0] == oauth_config.token_url
assert call_args[1]["data"]["grant_type"] == "authorization_code"
assert call_args[1]["data"]["code"] == "auth_code_123"
assert oauth_config.client_id is not None
assert oauth_config.client_secret is not None
assert call_args[1]["data"]["client_id"] == oauth_config.client_id.get_value(
apply_mask=False
)
assert call_args[1]["data"][
"client_secret"
] == oauth_config.client_secret.get_value(apply_mask=False)
assert call_args[1]["data"]["client_id"] == oauth_config.client_id
assert call_args[1]["data"]["client_secret"] == oauth_config.client_secret
assert call_args[1]["data"]["redirect_uri"] == "https://example.com/callback"
@patch("onyx.auth.oauth_token_manager.requests.post")

View File

@@ -950,7 +950,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
@@ -1292,21 +1291,12 @@ def test_code_interpreter_replay_packets_include_code_and_output(
tool_call_id="call_replay_test",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
).expect(
Packet(
placement=create_placement(0),
obj=ToolCallArgumentDelta(
tool_type="python",
argument_deltas={"code": code},
),
),
forward=2,
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
),
forward=False,
forward=2,
).expect(
Packet(
placement=create_placement(0),

View File

@@ -64,8 +64,7 @@ class TestBotConfigAPI:
db_session.commit()
assert config is not None
assert config.bot_token is not None
assert config.bot_token.get_value(apply_mask=False) == "test_token_123"
assert config.bot_token == "test_token_123"
# Cleanup
delete_discord_bot_config(db_session)

View File

@@ -1,165 +0,0 @@
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
to the EE implementations, so no patching of the MIT layer is needed.
"""
from unittest.mock import patch
import pytest
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from ee.onyx.utils.encryption import decrypt_bytes_to_string
from ee.onyx.utils.encryption import encrypt_string_to_bytes
EE_MODULE = "ee.onyx.utils.encryption"
# Keys must be exactly 16, 24, or 32 bytes for AES
KEY_16 = "a" * 16
KEY_16_ALT = "b" * 16
KEY_24 = "d" * 24
KEY_32 = "c" * 32
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
_get_trimmed_key.cache_clear()
class TestEncryptDecryptRoundTrip:
def test_roundtrip_with_env_key(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
encrypted = _encrypt_string("hello world")
assert encrypted != b"hello world"
assert _decrypt_bytes(encrypted) == "hello world"
def test_roundtrip_with_explicit_key(self) -> None:
encrypted = _encrypt_string("secret data", key=KEY_32)
assert encrypted != b"secret data"
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
def test_roundtrip_no_key(self) -> None:
"""Without any key, data is raw-encoded (no encryption)."""
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
encrypted = _encrypt_string("plain text")
assert encrypted == b"plain text"
assert _decrypt_bytes(encrypted) == "plain text"
def test_explicit_key_overrides_env(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
encrypted = _encrypt_string("data", key=KEY_16_ALT)
with pytest.raises(ValueError):
_decrypt_bytes(encrypted, key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
def test_different_encryptions_produce_different_bytes(self) -> None:
"""Each encryption uses a random IV, so results differ."""
a = _encrypt_string("same", key=KEY_16)
b = _encrypt_string("same", key=KEY_16)
assert a != b
def test_roundtrip_empty_string(self) -> None:
encrypted = _encrypt_string("", key=KEY_16)
assert encrypted != b""
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
def test_roundtrip_unicode(self) -> None:
text = "日本語テスト 🔐 émojis"
encrypted = _encrypt_string(text, key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16) == text
class TestDecryptFallbackBehavior:
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
raw = "readable text".encode()
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
assert _decrypt_bytes(raw) == "readable text"
def test_explicit_wrong_key_raises(self) -> None:
"""Explicit key path: AES fails → raises, no fallback."""
encrypted = _encrypt_string("secret", key=KEY_16)
with pytest.raises(ValueError):
_decrypt_bytes(encrypted, key=KEY_16_ALT)
def test_explicit_none_key_with_no_env(self) -> None:
"""key=None with empty env → raw decode."""
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
assert _decrypt_bytes(b"hello", key=None) == "hello"
def test_explicit_empty_string_key(self) -> None:
"""key='' means no encryption."""
encrypted = _encrypt_string("test", key="")
assert encrypted == b"test"
assert _decrypt_bytes(encrypted, key="") == "test"
class TestKeyValidation:
def test_key_too_short_raises(self) -> None:
with pytest.raises(RuntimeError, match="too short"):
_encrypt_string("data", key="short")
def test_16_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
def test_24_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_24)
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
def test_32_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_32)
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
def test_long_key_truncated_to_32(self) -> None:
"""Keys longer than 32 bytes are truncated to 32."""
long_key = "e" * 64
encrypted = _encrypt_string("data", key=long_key)
assert _decrypt_bytes(encrypted, key=long_key) == "data"
def test_20_byte_key_trimmed_to_16(self) -> None:
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
key_20 = "f" * 20
encrypted = _encrypt_string("data", key=key_20)
assert _decrypt_bytes(encrypted, key=key_20) == "data"
# Verify it was trimmed to 16 by checking that the first 16 bytes
# of the key can also decrypt it
key_16_same_prefix = "f" * 16
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
def test_25_byte_key_trimmed_to_24(self) -> None:
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
key_25 = "g" * 25
encrypted = _encrypt_string("data", key=key_25)
assert _decrypt_bytes(encrypted, key=key_25) == "data"
key_24_same_prefix = "g" * 24
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
def test_30_byte_key_trimmed_to_24(self) -> None:
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
key_30 = "h" * 30
encrypted = _encrypt_string("data", key=key_30)
assert _decrypt_bytes(encrypted, key=key_30) == "data"
key_24_same_prefix = "h" * 24
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
class TestWrapperFunctions:
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
With EE mode enabled, the wrappers resolve to EE implementations automatically.
"""
def test_wrapper_passes_key(self) -> None:
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
def test_wrapper_no_key_uses_env(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
encrypted = encrypt_string_to_bytes("payload")
assert decrypt_bytes_to_string(encrypted) == "payload"

View File

@@ -1,163 +0,0 @@
"""Tests for user file ACL computation, including shared persona access."""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.access.access import collect_user_file_access
from onyx.access.access import get_access_for_user_files_impl
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import PUBLIC_DOC_PAT
def _make_user(email: str) -> MagicMock:
user = MagicMock()
user.email = email
user.id = uuid4()
return user
def _make_persona(
*,
owner: MagicMock | None = None,
shared_users: list[MagicMock] | None = None,
is_public: bool = False,
deleted: bool = False,
) -> MagicMock:
persona = MagicMock()
persona.deleted = deleted
persona.is_public = is_public
persona.user_id = owner.id if owner else None
persona.user = owner
persona.users = shared_users or []
return persona
def _make_user_file(
*,
owner: MagicMock,
assistants: list[MagicMock] | None = None,
) -> MagicMock:
uf = MagicMock()
uf.id = uuid4()
uf.user = owner
uf.user_id = owner.id
uf.assistants = assistants or []
return uf
class TestCollectUserFileAccess:
def test_owner_only(self) -> None:
owner = _make_user("owner@test.com")
uf = _make_user_file(owner=owner)
emails, is_public = collect_user_file_access(uf)
assert emails == {"owner@test.com"}
assert is_public is False
def test_shared_persona_adds_users(self) -> None:
owner = _make_user("owner@test.com")
shared = _make_user("shared@test.com")
persona = _make_persona(owner=owner, shared_users=[shared])
uf = _make_user_file(owner=owner, assistants=[persona])
emails, is_public = collect_user_file_access(uf)
assert emails == {"owner@test.com", "shared@test.com"}
assert is_public is False
def test_persona_owner_added(self) -> None:
"""Persona owner (different from file owner) gets access too."""
file_owner = _make_user("file-owner@test.com")
persona_owner = _make_user("persona-owner@test.com")
persona = _make_persona(owner=persona_owner)
uf = _make_user_file(owner=file_owner, assistants=[persona])
emails, is_public = collect_user_file_access(uf)
assert "file-owner@test.com" in emails
assert "persona-owner@test.com" in emails
def test_public_persona_makes_file_public(self) -> None:
owner = _make_user("owner@test.com")
persona = _make_persona(owner=owner, is_public=True)
uf = _make_user_file(owner=owner, assistants=[persona])
emails, is_public = collect_user_file_access(uf)
assert is_public is True
assert "owner@test.com" in emails
def test_deleted_persona_ignored(self) -> None:
owner = _make_user("owner@test.com")
shared = _make_user("shared@test.com")
persona = _make_persona(owner=owner, shared_users=[shared], deleted=True)
uf = _make_user_file(owner=owner, assistants=[persona])
emails, is_public = collect_user_file_access(uf)
assert emails == {"owner@test.com"}
assert is_public is False
def test_multiple_personas_combine(self) -> None:
owner = _make_user("owner@test.com")
user_a = _make_user("a@test.com")
user_b = _make_user("b@test.com")
p1 = _make_persona(owner=owner, shared_users=[user_a])
p2 = _make_persona(owner=owner, shared_users=[user_b])
uf = _make_user_file(owner=owner, assistants=[p1, p2])
emails, is_public = collect_user_file_access(uf)
assert emails == {"owner@test.com", "a@test.com", "b@test.com"}
def test_deduplication(self) -> None:
owner = _make_user("owner@test.com")
shared = _make_user("shared@test.com")
p1 = _make_persona(owner=owner, shared_users=[shared])
p2 = _make_persona(owner=owner, shared_users=[shared])
uf = _make_user_file(owner=owner, assistants=[p1, p2])
emails, _ = collect_user_file_access(uf)
assert emails == {"owner@test.com", "shared@test.com"}
class TestGetAccessForUserFiles:
def test_shared_user_in_acl(self) -> None:
"""Shared persona users should appear in the ACL."""
owner = _make_user("owner@test.com")
shared = _make_user("shared@test.com")
persona = _make_persona(owner=owner, shared_users=[shared])
uf = _make_user_file(owner=owner, assistants=[persona])
db_session = MagicMock()
with patch(
"onyx.access.access.fetch_user_files_with_access_relationships",
return_value=[uf],
):
result = get_access_for_user_files_impl([str(uf.id)], db_session)
access = result[str(uf.id)]
acl = access.to_acl()
assert prefix_user_email("owner@test.com") in acl
assert prefix_user_email("shared@test.com") in acl
assert access.is_public is False
def test_public_persona_sets_public_acl(self) -> None:
owner = _make_user("owner@test.com")
persona = _make_persona(owner=owner, is_public=True)
uf = _make_user_file(owner=owner, assistants=[persona])
db_session = MagicMock()
with patch(
"onyx.access.access.fetch_user_files_with_access_relationships",
return_value=[uf],
):
result = get_access_for_user_files_impl([str(uf.id)], db_session)
access = result[str(uf.id)]
assert access.is_public is True
acl = access.to_acl()
assert PUBLIC_DOC_PAT in acl

View File

@@ -1,54 +0,0 @@
from unittest.mock import MagicMock
import pytest
import onyx.auth.users as users
from onyx.auth.users import verify_auth_setting
from onyx.configs.constants import AuthType
def test_verify_auth_setting_raises_for_cloud(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Cloud auth type is not valid for self-hosted deployments."""
monkeypatch.setenv("AUTH_TYPE", "cloud")
with pytest.raises(ValueError, match="'cloud' is not a valid auth type"):
verify_auth_setting()
def test_verify_auth_setting_warns_for_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Disabled auth type logs a deprecation warning."""
monkeypatch.setenv("AUTH_TYPE", "disabled")
mock_logger = MagicMock()
monkeypatch.setattr(users, "logger", mock_logger)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.BASIC)
verify_auth_setting()
mock_logger.warning.assert_called_once()
assert "no longer supported" in mock_logger.warning.call_args[0][0]
@pytest.mark.parametrize(
"auth_type",
[AuthType.BASIC, AuthType.GOOGLE_OAUTH, AuthType.OIDC, AuthType.SAML],
)
def test_verify_auth_setting_valid_auth_types(
monkeypatch: pytest.MonkeyPatch,
auth_type: AuthType,
) -> None:
"""Valid auth types work without errors or warnings."""
monkeypatch.setenv("AUTH_TYPE", auth_type.value)
mock_logger = MagicMock()
monkeypatch.setattr(users, "logger", mock_logger)
monkeypatch.setattr(users, "AUTH_TYPE", auth_type)
verify_auth_setting()
mock_logger.warning.assert_not_called()
mock_logger.notice.assert_called_once_with(f"Using Auth Type: {auth_type.value}")

View File

@@ -27,6 +27,7 @@ def _mock_session_returning_none() -> MagicMock:
"""Return a mock session whose .get() returns None (file not found)."""
session = MagicMock()
session.get.return_value = None
session.execute.return_value.scalar_one_or_none.return_value = None
return session
@@ -219,10 +220,6 @@ class TestDeleteUserFileImpl:
# ------------------------------------------------------------------
@patch(
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
return_value=[],
)
class TestProjectSyncUserFileImpl:
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
@patch(f"{TASKS_MODULE}.get_redis_client")
@@ -230,7 +227,6 @@ class TestProjectSyncUserFileImpl:
self,
mock_get_redis: MagicMock,
mock_get_session: MagicMock,
_mock_fetch: MagicMock,
) -> None:
redis_client = MagicMock()
lock = MagicMock()
@@ -259,7 +255,6 @@ class TestProjectSyncUserFileImpl:
self,
mock_get_redis: MagicMock,
mock_get_session: MagicMock,
_mock_fetch: MagicMock,
) -> None:
redis_client = MagicMock()
lock = MagicMock()
@@ -282,7 +277,6 @@ class TestProjectSyncUserFileImpl:
self,
mock_get_redis: MagicMock,
mock_get_session: MagicMock,
_mock_fetch: MagicMock,
) -> None:
session = _mock_session_returning_none()
mock_get_session.return_value.__enter__.return_value = session

View File

@@ -379,13 +379,10 @@ class TestProjectSyncImplNoVectorDb:
) -> None:
uf = _make_user_file(status=UserFileStatus.COMPLETED)
session = MagicMock()
session.execute.return_value.scalar_one_or_none.return_value = uf
mock_get_session.return_value.__enter__.return_value = session
with (
patch(
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
return_value=[uf],
),
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
@@ -408,17 +405,14 @@ class TestProjectSyncImplNoVectorDb:
) -> None:
uf = _make_user_file(status=UserFileStatus.COMPLETED)
session = MagicMock()
session.execute.return_value.scalar_one_or_none.return_value = uf
mock_get_session.return_value.__enter__.return_value = session
with patch(
f"{TASKS_MODULE}.fetch_user_files_with_access_relationships",
return_value=[uf],
):
project_sync_user_file_impl(
user_file_id=str(uf.id),
tenant_id="test-tenant",
redis_locking=False,
)
project_sync_user_file_impl(
user_file_id=str(uf.id),
tenant_id="test-tenant",
redis_locking=False,
)
assert uf.needs_project_sync is False
assert uf.needs_persona_sync is False

View File

@@ -1,630 +0,0 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.utils.jsonriver import Parser
def _make_tool_call_delta(
index: int = 0,
name: str | None = None,
arguments: str | None = None,
function_is_none: bool = False,
) -> MagicMock:
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
delta = MagicMock()
delta.index = index
if function_is_none:
delta.function = None
else:
delta.function = MagicMock()
delta.function.name = name
delta.function.arguments = arguments
return delta
def _make_placement() -> Placement:
return Placement(turn_index=0, tab_index=0)
def _mock_tool_class(emit: bool = True) -> MagicMock:
cls = MagicMock()
cls.should_emit_argument_deltas.return_value = emit
return cls
def _collect(
tc_map: dict[int, dict[str, Any]],
delta: MagicMock,
placement: Placement | None = None,
parsers: dict[int, Parser] | None = None,
) -> list[Any]:
"""Run maybe_emit_argument_delta and return the yielded packets."""
return list(
maybe_emit_argument_delta(
tc_map,
delta,
placement or _make_placement(),
parsers if parsers is not None else {},
)
)
def _stream_fragments(
fragments: list[str],
tc_map: dict[int, dict[str, Any]],
placement: Placement | None = None,
) -> list[str]:
"""Feed fragments into maybe_emit_argument_delta one by one, returning
all emitted content values concatenated per-key as a flat list."""
pl = placement or _make_placement()
parsers: dict[int, Parser] = {}
emitted: list[str] = []
for frag in fragments:
tc_map[0]["arguments"] += frag
delta = _make_tool_call_delta(arguments=frag)
for packet in maybe_emit_argument_delta(tc_map, delta, pl, parsers=parsers):
obj = packet.obj
assert isinstance(obj, ToolCallArgumentDelta)
for value in obj.argument_deltas.values():
emitted.append(value)
return emitted
class TestMaybeEmitArgumentDeltaGuards:
"""Tests for conditions that cause no packet to be emitted."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_tool_does_not_opt_in(
self, mock_get_tool: MagicMock
) -> None:
"""Tools that return False from should_emit_argument_deltas emit nothing."""
mock_get_tool.return_value = _mock_tool_class(emit=False)
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_tool_class_unknown(
self, mock_get_tool: MagicMock
) -> None:
mock_get_tool.return_value = None
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_no_argument_fragment(
self, mock_get_tool: MagicMock
) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_key_value_not_started(
self, mock_get_tool: MagicMock
) -> None:
"""Key exists in JSON but its string value hasn't begun yet."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
"""Only the opening brace has arrived — no key to stream yet."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": "{"}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
class TestMaybeEmitArgumentDeltaBasic:
"""Tests for correct packet content and incremental emission."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "', "print(1)", '"}']
pl = _make_placement()
parsers: dict[int, Parser] = {}
all_packets = []
for frag in fragments:
tc_map[0]["arguments"] += frag
packets = _collect(
tc_map, _make_tool_call_delta(arguments=frag), pl, parsers
)
all_packets.extend(packets)
assert len(all_packets) >= 1
# Verify packet structure
obj = all_packets[0].obj
assert isinstance(obj, ToolCallArgumentDelta)
assert obj.tool_type == "python"
# All emitted content should reconstruct the value
full_code = ""
for p in all_packets:
assert isinstance(p.obj, ToolCallArgumentDelta)
if "code" in p.obj.argument_deltas:
full_code += p.obj.argument_deltas["code"]
assert full_code == "print(1)"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_emits_only_new_content_on_subsequent_call(
self, mock_get_tool: MagicMock
) -> None:
"""After a first emission, subsequent calls emit only the diff."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
parsers: dict[int, Parser] = {}
pl = _make_placement()
# First fragment opens the string
tc_map[0]["arguments"] = '{"code": "abc'
packets_1 = _collect(
tc_map, _make_tool_call_delta(arguments='{"code": "abc'), pl, parsers
)
code_1 = ""
for p in packets_1:
assert isinstance(p.obj, ToolCallArgumentDelta)
code_1 += p.obj.argument_deltas.get("code", "")
assert code_1 == "abc"
# Second fragment appends more
tc_map[0]["arguments"] = '{"code": "abcdef'
packets_2 = _collect(
tc_map, _make_tool_call_delta(arguments="def"), pl, parsers
)
code_2 = ""
for p in packets_2:
assert isinstance(p.obj, ToolCallArgumentDelta)
code_2 += p.obj.argument_deltas.get("code", "")
assert code_2 == "def"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
"""When a second key starts, emissions switch to that key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "x',
'", "output": "hello',
'"}',
]
emitted = _stream_fragments(fragments, tc_map)
full = "".join(emitted)
assert "x" in full
assert "hello" in full
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
"""A single delta contains the end of one value and the start of the next key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "x',
'y", "lang": "py',
'"}',
]
emitted = _stream_fragments(fragments, tc_map)
full = "".join(emitted)
assert "xy" in full
assert "py" in full
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
"""An empty string value has nothing to emit."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
# Opening quote just arrived, value is empty
tc_map[0]["arguments"] = '{"code": "'
packets = _collect(tc_map, _make_tool_call_delta(arguments='{"code": "'))
# No string content yet, so either no packet or empty deltas
for p in packets:
assert isinstance(p.obj, ToolCallArgumentDelta)
assert p.obj.argument_deltas.get("code", "") == ""
class TestMaybeEmitArgumentDeltaDecoding:
"""Tests verifying that JSON escape sequences are properly decoded."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "line1\\nline2"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "line1\nline2"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "\\tindented"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "\tindented"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "say \\"hi\\""}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == 'say "hi"'
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "path\\\\dir"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "path\\dir"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "\\u0041"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "A"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_incomplete_escape_at_end_decoded_on_next_chunk(
self, mock_get_tool: MagicMock
) -> None:
"""A trailing backslash (incomplete escape) is completed in the next chunk."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "hello\\', 'n"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "hello\n"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_incomplete_unicode_escape_completed_on_next_chunk(
self, mock_get_tool: MagicMock
) -> None:
"""A partial \\uXX sequence is completed in the next chunk."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"code": "hello\\u00', '41"}']
emitted = _stream_fragments(fragments, tc_map)
assert "".join(emitted) == "helloA"
class TestArgumentDeltaStreamingE2E:
"""Simulates realistic sequences of LLM argument deltas to verify
the full pipeline produces correct decoded output."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"',
"code",
'": "',
"print(",
"'hello')",
"\\n",
"print(",
"'world')",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "print('hello')\nprint('world')"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
"""Streams code with tabs and newlines."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"if True:",
"\\n",
"\\t",
"pass",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "if True:\n\tpass"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
"""An escape sequence split across two fragments (backslash in one,
'n' in the next) should still decode correctly."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "hello',
"\\",
"n",
'world"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "hello\nworld"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
"""Streams a multi-line function with multiple escape sequences."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"def foo():",
"\\n",
"\\t",
"x = 1",
"\\n",
"\\t",
"return x",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "def foo():\n\tx = 1\n\treturn x"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
"""Streams code first, then a second key (language) — both decoded."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"x = 1",
'", "language": "',
"python",
'"}',
]
emitted = _stream_fragments(fragments, tc_map)
# Should have emissions for both keys
full = "".join(emitted)
assert "x = 1" in full
assert "python" in full
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
The escaped quotes inside the *outer* JSON value should prevent the
inner `"key":` from being mistaken for a top-level JSON key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
# The inner quotes are escaped as \" in the JSON value.
fragments = [
'{"code": "',
"x = {",
'\\"key\\"',
": ",
'\\"val\\"',
"}",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == 'x = {"key": "val"}'
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
"""Colons inside the string value should not confuse key detection."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"url = ",
'\\"https://example.com\\"',
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == 'url = "https://example.com"'
class TestMaybeEmitArgumentDeltaEdgeCases:
"""Edge cases not covered by the standard test classes."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
"""Some delta chunks have function=None (e.g. role-only deltas)."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
assert _collect(tc_map, delta) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
"""Two tool calls streaming at different indices in parallel."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""},
1: {"id": "tc_2", "name": "python", "arguments": ""},
}
parsers: dict[int, Parser] = {}
pl = _make_placement()
# Feed full JSON to index 0
tc_map[0]["arguments"] = '{"code": "aaa"}'
packets_0 = _collect(
tc_map,
_make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
pl,
parsers,
)
code_0 = ""
for p in packets_0:
assert isinstance(p.obj, ToolCallArgumentDelta)
code_0 += p.obj.argument_deltas.get("code", "")
assert code_0 == "aaa"
# Feed full JSON to index 1
tc_map[1]["arguments"] = '{"code": "bbb"}'
packets_1 = _collect(
tc_map,
_make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
pl,
parsers,
)
code_1 = ""
for p in packets_1:
assert isinstance(p.obj, ToolCallArgumentDelta)
code_1 += p.obj.argument_deltas.get("code", "")
assert code_1 == "bbb"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
"""A single delta contains four complete key-value pairs."""
mock_get_tool.return_value = _mock_tool_class()
full = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
tc_map[0]["arguments"] = full
parsers: dict[int, Parser] = {}
packets = _collect(
tc_map, _make_tool_call_delta(arguments=full), parsers=parsers
)
# Collect all argument deltas across packets
all_deltas: dict[str, str] = {}
for p in packets:
assert isinstance(p.obj, ToolCallArgumentDelta)
for k, v in p.obj.argument_deltas.items():
all_deltas[k] = all_deltas.get(k, "") + v
assert all_deltas == {
"a": "one",
"b": "two",
"c": "three",
"d": "four",
}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_on_second_arg_after_first_complete(
self, mock_get_tool: MagicMock
) -> None:
"""First argument is fully complete; delta only adds to the second."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "print(1)", "lang": "py',
'"}',
]
emitted = _stream_fragments(fragments, tc_map)
full = "".join(emitted)
assert "print(1)" in full
assert "py" in full
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
"""Non-string values (numbers, booleans, null) are skipped — they are
available in the final tool-call kickoff packet. String arguments
following them are still emitted."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = ['{"timeout": 30, "code": "hello"}']
emitted = _stream_fragments(fragments, tc_map)
full = "".join(emitted)
assert full == "hello"

View File

@@ -0,0 +1,325 @@
"""Unit tests for SharepointConnector._fetch_site_pages error handling.
Covers 404 handling (classic sites / no modern pages) and 400
canvasLayout fallback (corrupt pages causing $expand=canvasLayout to
fail on the LIST endpoint).
"""
from __future__ import annotations
import json
from typing import Any
import pytest
from requests import Response
from requests.exceptions import HTTPError
from onyx.connectors.sharepoint.connector import GRAPH_INVALID_REQUEST_CODE
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.sharepoint.connector import SiteDescriptor
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
PAGES_COLLECTION = f"https://graph.microsoft.com/v1.0/sites/{FAKE_SITE_ID}/pages"
SITE_PAGES_BASE = f"{PAGES_COLLECTION}/microsoft.graph.sitePage"
def _site_descriptor() -> SiteDescriptor:
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
def _make_http_error(
status_code: int,
error_code: str = "itemNotFound",
message: str = "Item not found",
) -> HTTPError:
body = {"error": {"code": error_code, "message": message}}
response = Response()
response.status_code = status_code
response._content = json.dumps(body).encode()
response.headers["Content-Type"] = "application/json"
return HTTPError(response=response)
def _setup_connector(
monkeypatch: pytest.MonkeyPatch, # noqa: ARG001
) -> SharepointConnector:
"""Create a connector with the graph client and site resolution mocked."""
connector = SharepointConnector(sites=[SITE_URL])
connector.graph_api_base = "https://graph.microsoft.com/v1.0"
mock_sites = type(
"FakeSites",
(),
{
"get_by_url": staticmethod(
lambda url: type( # noqa: ARG005
"Q",
(),
{
"execute_query": lambda self: None, # noqa: ARG005
"id": FAKE_SITE_ID,
},
)()
),
},
)()
connector._graph_client = type("FakeGraphClient", (), {"sites": mock_sites})()
return connector
def _patch_graph_api_get_json(
monkeypatch: pytest.MonkeyPatch,
fake_fn: Any,
) -> None:
monkeypatch.setattr(SharepointConnector, "_graph_api_get_json", fake_fn)
class TestFetchSitePages404:
def test_404_yields_no_pages(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""A 404 from the Pages API should result in zero yielded pages."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert pages == []
def test_404_does_not_raise(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""A 404 must not propagate as an exception."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
for _ in connector._fetch_site_pages(_site_descriptor()):
pass
def test_non_404_http_error_still_raises(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Non-404 HTTP errors (e.g. 403) must still propagate."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(403)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
with pytest.raises(HTTPError):
list(connector._fetch_site_pages(_site_descriptor()))
def test_successful_fetch_yields_pages(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""When the API succeeds, pages should be yielded normally."""
connector = _setup_connector(monkeypatch)
fake_page = {
"id": "page-1",
"title": "Hello World",
"webUrl": f"{SITE_URL}/SitePages/Hello.aspx",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
return {"value": [fake_page]}
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert len(pages) == 1
assert pages[0]["id"] == "page-1"
def test_404_on_second_page_stops_pagination(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""If the first API page succeeds but a nextLink returns 404,
already-yielded pages are kept and iteration stops cleanly."""
connector = _setup_connector(monkeypatch)
call_count = 0
first_page = {
"id": "page-1",
"title": "First",
"webUrl": f"{SITE_URL}/SitePages/First.aspx",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
nonlocal call_count
call_count += 1
if call_count == 1:
return {
"value": [first_page],
"@odata.nextLink": "https://graph.microsoft.com/next",
}
raise _make_http_error(404)
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert len(pages) == 1
assert pages[0]["id"] == "page-1"
class TestFetchSitePages400Fallback:
"""When $expand=canvasLayout on the LIST endpoint returns 400
invalidRequest, _fetch_site_pages should fall back to listing
without expansion, then expanding each page individually."""
GOOD_PAGE: dict[str, Any] = {
"id": "good-1",
"name": "Good.aspx",
"title": "Good Page",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
BAD_PAGE: dict[str, Any] = {
"id": "bad-1",
"name": "Bad.aspx",
"title": "Bad Page",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
GOOD_PAGE_EXPANDED: dict[str, Any] = {
**GOOD_PAGE,
"canvasLayout": {"horizontalSections": []},
}
def test_fallback_expands_good_pages_individually(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""On 400 from the LIST expand, the connector should list without
expand, then GET each page individually with $expand=canvasLayout."""
connector = _setup_connector(monkeypatch)
good_page = self.GOOD_PAGE
bad_page = self.BAD_PAGE
good_page_expanded = self.GOOD_PAGE_EXPANDED
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str,
params: dict[str, str] | None = None,
) -> dict[str, Any]:
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
raise _make_http_error(
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
)
if url == SITE_PAGES_BASE and params is None:
return {"value": [good_page, bad_page]}
expand_params = {"$expand": "canvasLayout"}
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
assert params == expand_params, f"Expected $expand params, got {params}"
return good_page_expanded
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
assert params == expand_params, f"Expected $expand params, got {params}"
raise _make_http_error(
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
)
raise AssertionError(f"Unexpected call: {url} {params}")
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
assert len(pages) == 2
assert pages[0].get("canvasLayout") is not None
assert pages[1].get("canvasLayout") is None
assert pages[1]["id"] == "bad-1"
def test_mid_pagination_400_does_not_duplicate(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""If the first paginated batch succeeds but a later nextLink
returns 400, pages from the first batch must not be re-yielded
by the fallback."""
connector = _setup_connector(monkeypatch)
good_page = self.GOOD_PAGE
good_page_expanded = self.GOOD_PAGE_EXPANDED
bad_page = self.BAD_PAGE
second_page = {
"id": "page-2",
"name": "Second.aspx",
"title": "Second Page",
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
}
next_link = "https://graph.microsoft.com/v1.0/next-page-link"
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str,
params: dict[str, str] | None = None,
) -> dict[str, Any]:
if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
return {
"value": [good_page],
"@odata.nextLink": next_link,
}
if url == next_link:
raise _make_http_error(
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
)
if url == SITE_PAGES_BASE and params is None:
return {"value": [good_page, bad_page, second_page]}
expand_params = {"$expand": "canvasLayout"}
if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
assert params == expand_params, f"Expected $expand params, got {params}"
return good_page_expanded
if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
assert params == expand_params, f"Expected $expand params, got {params}"
raise _make_http_error(
400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
)
if url == f"{PAGES_COLLECTION}/page-2/microsoft.graph.sitePage":
assert params == expand_params, f"Expected $expand params, got {params}"
return {**second_page, "canvasLayout": {"horizontalSections": []}}
raise AssertionError(f"Unexpected call: {url} {params}")
_patch_graph_api_get_json(monkeypatch, fake_get_json)
pages = list(connector._fetch_site_pages(_site_descriptor()))
ids = [p["id"] for p in pages]
assert ids == ["good-1", "bad-1", "page-2"]
def test_non_invalid_request_400_still_raises(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A 400 with a different error code (not invalidRequest) should
propagate, not trigger the fallback."""
connector = _setup_connector(monkeypatch)
def fake_get_json(
self: SharepointConnector, # noqa: ARG001
url: str, # noqa: ARG001
params: dict[str, str] | None = None, # noqa: ARG001
) -> dict[str, Any]:
raise _make_http_error(400, "badRequest", "Something else went wrong")
_patch_graph_api_get_json(monkeypatch, fake_get_json)
with pytest.raises(HTTPError):
list(connector._fetch_site_pages(_site_descriptor()))

View File

@@ -7,6 +7,7 @@ import pytest
from onyx.db.llm import sync_model_configurations
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import SyncModelEntry
class TestSyncModelConfigurations:
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4",
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o",
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4",
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o",
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Existing - should be skipped
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o", # New - should be inserted
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Existing - should be skipped
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o", # New - should be inserted
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Already exists
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Already exists
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
sync_model_configurations(
db_session=mock_session,
provider_name="nonexistent",
models=[{"name": "model", "display_name": "Model"}],
models=[SyncModelEntry(name="model", display_name="Model")],
)
def test_handles_missing_optional_fields(self) -> None:
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
with patch(
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
# Model with only required fields
# Model with only required fields (max_input_tokens and supports_image_input default)
models = [
{
"name": "model-1",
# No display_name, max_input_tokens, or supports_image_input
},
SyncModelEntry(
name="model-1",
display_name="Model 1",
),
]
result = sync_model_configurations(

View File

@@ -1,196 +0,0 @@
import io
import openpyxl
from onyx.file_processing.extract_file_text import xlsx_to_text
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
wb = openpyxl.Workbook()
if wb.active is not None:
wb.remove(wb.active)
for sheet_name, rows in sheets.items():
ws = wb.create_sheet(title=sheet_name)
for row in rows:
ws.append(row)
buf = io.BytesIO()
wb.save(buf)
buf.seek(0)
return buf
class TestXlsxToText:
def test_single_sheet_basic(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["Name", "Age"],
["Alice", "30"],
["Bob", "25"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "Name" in lines[0]
assert "Age" in lines[0]
assert "Alice" in lines[1]
assert "30" in lines[1]
assert "Bob" in lines[2]
def test_multiple_sheets_separated(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [["a", "b"]],
"Sheet2": [["c", "d"]],
}
)
result = xlsx_to_text(xlsx)
# TEXT_SECTION_SEPARATOR is "\n\n"
assert "\n\n" in result
parts = result.split("\n\n")
assert any("a" in p for p in parts)
assert any("c" in p for p in parts)
def test_empty_cells(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "b"],
["", "c", ""],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
def test_commas_in_cells_are_quoted(self) -> None:
"""Cells containing commas should be quoted in CSV output."""
xlsx = _make_xlsx(
{
"Sheet1": [
["hello, world", "normal"],
]
}
)
result = xlsx_to_text(xlsx)
assert '"hello, world"' in result
def test_empty_workbook(self) -> None:
xlsx = _make_xlsx({"Sheet1": []})
result = xlsx_to_text(xlsx)
assert result.strip() == ""
def test_long_empty_row_run_capped(self) -> None:
"""Runs of >2 empty rows should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["header"],
[""],
[""],
[""],
[""],
["data"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
assert len(lines) == 4
assert "header" in lines[0]
assert "data" in lines[-1]
def test_long_empty_col_run_capped(self) -> None:
"""Runs of >2 empty columns should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "", "", "b"],
["c", "", "", "", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
# Each row should have 4 fields (a + 2 empty + b), not 5
# csv format: a,,,b (3 commas = 4 fields)
first_line = lines[0].strip()
# Count commas to verify column reduction
assert first_line.count(",") == 3
def test_short_empty_runs_kept(self) -> None:
"""Runs of <=2 empty rows/cols should be preserved."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "b"],
["", ""],
["", ""],
["c", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# All 4 rows preserved (2 empty rows <= threshold)
assert len(lines) == 4
def test_bad_zip_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="test.xlsx")
assert result == ""
def test_bad_zip_tilde_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
assert result == ""
def test_large_sparse_sheet(self) -> None:
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
rows: list[list[str]] = [["row1_data"]]
rows.extend([[""] for _ in range(10)])
rows.append(["row2_data"])
xlsx = _make_xlsx({"Sheet1": rows})
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
assert len(lines) == 4
assert "row1_data" in lines[0]
assert "row2_data" in lines[-1]
def test_quotes_in_cells(self) -> None:
"""Cells containing quotes should be properly escaped."""
xlsx = _make_xlsx(
{
"Sheet1": [
['say "hello"', "normal"],
]
}
)
result = xlsx_to_text(xlsx)
# csv.writer escapes quotes by doubling them
assert '""hello""' in result
def test_each_row_is_separate_line(self) -> None:
"""Each row should produce its own line (regression for writerow vs writerows)."""
xlsx = _make_xlsx(
{
"Sheet1": [
["r1c1", "r1c2"],
["r2c1", "r2c2"],
["r3c1", "r3c2"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "r1c1" in lines[0] and "r1c2" in lines[0]
assert "r2c1" in lines[1] and "r2c2" in lines[1]
assert "r3c1" in lines[2] and "r3c2" in lines[2]

View File

@@ -26,14 +26,6 @@ class TestIsTrueOpenAIModel:
"""Test that real OpenAI GPT-4o-mini model is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True
def test_real_openai_o1_preview(self) -> None:
"""Test that real OpenAI o1-preview reasoning model is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True
def test_real_openai_o1_mini(self) -> None:
"""Test that real OpenAI o1-mini reasoning model is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True
def test_openai_with_provider_prefix(self) -> None:
"""Test that OpenAI model with provider prefix is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False

View File

@@ -0,0 +1,204 @@
"""Tests for Slack channel reference resolution and tag filtering
in handle_regular_answer.py."""
from unittest.mock import MagicMock
from slack_sdk.errors import SlackApiError
from onyx.context.search.models import Tag
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _mock_client_with_channels(
channel_map: dict[str, str],
) -> MagicMock:
"""Return a mock WebClient where conversations_info resolves IDs to names."""
client = MagicMock()
def _conversations_info(channel: str) -> MagicMock:
if channel in channel_map:
resp = MagicMock()
resp.validate = MagicMock()
resp.__getitem__ = lambda _self, key: {
"channel": {
"name": channel_map[channel],
"is_im": False,
"is_mpim": False,
}
}[key]
return resp
raise SlackApiError("channel_not_found", response=MagicMock())
client.conversations_info = _conversations_info
return client
def _mock_logger() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# SLACK_CHANNEL_REF_PATTERN regex tests
# ---------------------------------------------------------------------------
class TestSlackChannelRefPattern:
def test_matches_bare_channel_id(self) -> None:
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
assert matches == [("C097NBWMY8Y", "")]
def test_matches_channel_id_with_name(self) -> None:
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
assert matches == [("C097NBWMY8Y", "eng-infra")]
def test_matches_multiple_channels(self) -> None:
msg = "compare <#C111AAA> and <#C222BBB|general>"
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
assert len(matches) == 2
assert ("C111AAA", "") in matches
assert ("C222BBB", "general") in matches
def test_no_match_on_plain_text(self) -> None:
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
assert matches == []
def test_no_match_on_user_mention(self) -> None:
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
assert matches == []
# ---------------------------------------------------------------------------
# resolve_channel_references tests
# ---------------------------------------------------------------------------
class TestResolveChannelReferences:
def test_resolves_bare_channel_id_via_api(self) -> None:
client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
logger = _mock_logger()
message, tags = resolve_channel_references(
message="summary of <#C097NBWMY8Y> this week",
client=client,
logger=logger,
)
assert message == "summary of #eng-infra this week"
assert len(tags) == 1
assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")
def test_uses_name_from_pipe_format_without_api_call(self) -> None:
client = MagicMock()
logger = _mock_logger()
message, tags = resolve_channel_references(
message="check <#C097NBWMY8Y|eng-infra> for updates",
client=client,
logger=logger,
)
assert message == "check #eng-infra for updates"
assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
# Should NOT have called the API since name was in the markup
client.conversations_info.assert_not_called()
def test_multiple_channels(self) -> None:
client = _mock_client_with_channels(
{
"C111AAA": "eng-infra",
"C222BBB": "eng-general",
}
)
logger = _mock_logger()
message, tags = resolve_channel_references(
message="compare <#C111AAA> and <#C222BBB>",
client=client,
logger=logger,
)
assert "#eng-infra" in message
assert "#eng-general" in message
assert "<#" not in message
assert len(tags) == 2
tag_values = {t.tag_value for t in tags}
assert tag_values == {"eng-infra", "eng-general"}
def test_no_channel_references_returns_unchanged(self) -> None:
client = MagicMock()
logger = _mock_logger()
message, tags = resolve_channel_references(
message="just a normal message with no channels",
client=client,
logger=logger,
)
assert message == "just a normal message with no channels"
assert tags == []
def test_api_failure_skips_channel_gracefully(self) -> None:
# Client that fails for all channel lookups
client = _mock_client_with_channels({})
logger = _mock_logger()
message, tags = resolve_channel_references(
message="check <#CBADID123>",
client=client,
logger=logger,
)
# Message should remain unchanged for the failed channel
assert "<#CBADID123>" in message
assert tags == []
logger.warning.assert_called_once()
def test_partial_failure_resolves_what_it_can(self) -> None:
# Only one of two channels resolves
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
logger = _mock_logger()
message, tags = resolve_channel_references(
message="compare <#C111AAA> and <#CBADID123>",
client=client,
logger=logger,
)
assert "#eng-infra" in message
assert "<#CBADID123>" in message # failed one stays raw
assert len(tags) == 1
assert tags[0].tag_value == "eng-infra"
def test_duplicate_channel_produces_single_tag(self) -> None:
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
logger = _mock_logger()
message, tags = resolve_channel_references(
message="summarize <#C111AAA> and compare with <#C111AAA>",
client=client,
logger=logger,
)
assert message == "summarize #eng-infra and compare with #eng-infra"
assert len(tags) == 1
assert tags[0].tag_value == "eng-infra"
def test_mixed_pipe_and_bare_formats(self) -> None:
client = _mock_client_with_channels({"C222BBB": "random"})
logger = _mock_logger()
message, tags = resolve_channel_references(
message="see <#C111AAA|eng-infra> and <#C222BBB>",
client=client,
logger=logger,
)
assert "#eng-infra" in message
assert "#random" in message
assert len(tags) == 2

View File

@@ -1,43 +0,0 @@
"""Unit tests for _get_user_access_info helper function.
These tests mock all database operations and don't require a real database.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from sqlalchemy.orm import Session
from onyx.server.features.hierarchy.api import _get_user_access_info
def test_get_user_access_info_returns_email_and_groups() -> None:
"""_get_user_access_info returns the user's email and external group IDs."""
mock_user = MagicMock()
mock_user.email = "test@example.com"
mock_db_session = MagicMock(spec=Session)
with patch(
"onyx.server.features.hierarchy.api.get_user_external_group_ids",
return_value=["group1", "group2"],
):
email, groups = _get_user_access_info(mock_user, mock_db_session)
assert email == "test@example.com"
assert groups == ["group1", "group2"]
def test_get_user_access_info_with_no_groups() -> None:
"""User with no external groups returns empty list."""
mock_user = MagicMock()
mock_user.email = "solo@example.com"
mock_db_session = MagicMock(spec=Session)
with patch(
"onyx.server.features.hierarchy.api.get_user_external_group_ids",
return_value=[],
):
email, groups = _get_user_access_info(mock_user, mock_db_session)
assert email == "solo@example.com"
assert groups == []

View File

@@ -1,15 +1,19 @@
"""Tests for LLM model fetch endpoints.
These tests verify the full request/response flow for fetching models
from dynamic providers (Ollama, OpenRouter), including the
from dynamic providers (Ollama, OpenRouter, Litellm), including the
sync-to-DB behavior when provider_name is specified.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
from onyx.server.manage.llm.models import LMStudioModelsRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
request = LMStudioModelsRequest(api_base="http://localhost:1234")
with pytest.raises(OnyxError):
get_lm_studio_available_models(request, MagicMock(), mock_session)
class TestGetLitellmAvailableModels:
"""Tests for the Litellm proxy model fetch endpoint."""
@pytest.fixture
def mock_litellm_response(self) -> dict:
"""Mock response from Litellm /v1/models endpoint."""
return {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
{
"id": "claude-3-5-sonnet",
"object": "model",
"created": 1700000001,
"owned_by": "anthropic",
},
{
"id": "gemini-pro",
"object": "model",
"created": 1700000002,
"owned_by": "google",
},
]
}
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
"""Test that endpoint returns properly formatted model list."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
assert len(results) == 3
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
"""Test that provider_name and model_name are correctly extracted."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
gpt = next(r for r in results if r.model_name == "gpt-4o")
assert gpt.provider_name == "openai"
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
assert claude.provider_name == "anthropic"
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
"""Test that results are alphabetically sorted by model_name."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
model_names = [r.model_name for r in results]
assert model_names == sorted(model_names, key=str.lower)
def test_empty_data_raises_onyx_error(self) -> None:
"""Test that empty model list raises OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {"data": []}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="No models found"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_missing_data_key_raises_onyx_error(self) -> None:
"""Test that response without 'data' key raises OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_skips_unparseable_entries(self) -> None:
"""Test that malformed model entries are skipped without failing."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
response_with_bad_entry = {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
# Missing required fields
{"bad_field": "bad_value"},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = response_with_bad_entry
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
assert len(results) == 1
assert results[0].model_name == "gpt-4o"
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
"""Test that OnyxError is raised when all entries fail to parse."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
response_all_bad = {
"data": [
{"bad_field": "bad_value"},
{"another_bad": 123},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = response_all_bad
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="No compatible models"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_api_base_trailing_slash_handled(self) -> None:
"""Test that trailing slashes in api_base are handled correctly."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
mock_litellm_response = {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000/",
api_key="test-key",
)
get_litellm_available_models(request, MagicMock(), mock_session)
# Should call /v1/models without double slashes
call_args = mock_get.call_args
assert call_args[0][0] == "http://localhost:4000/v1/models"
def test_connection_failure_raises_onyx_error(self) -> None:
"""Test that connection failures are wrapped in OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_get.side_effect = Exception("Connection refused")
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_401_raises_authentication_error(self) -> None:
"""Test that a 401 response raises OnyxError with authentication message."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 401
mock_get.side_effect = httpx.HTTPStatusError(
"Unauthorized", request=MagicMock(), response=mock_response
)
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="bad-key",
)
with pytest.raises(OnyxError, match="Authentication failed"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_404_raises_not_found_error(self) -> None:
"""Test that a 404 response raises OnyxError with endpoint not found message."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 404
mock_get.side_effect = httpx.HTTPStatusError(
"Not Found", request=MagicMock(), response=mock_response
)
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="endpoint not found"):
get_litellm_available_models(request, MagicMock(), mock_session)

View File

@@ -1,188 +0,0 @@
from io import BytesIO
from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.server.features.projects import projects_file_utils as utils
class _Tokenizer:
def encode(self, text: str) -> list[int]:
return [1] * len(text)
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
raise OSError("tell not supported")
def seek(self, *_args: object, **_kwargs: object) -> int:
raise OSError("seek not supported")
def _make_upload(filename: str, size: int, content: bytes | None = None) -> UploadFile:
payload = content if content is not None else (b"x" * size)
return UploadFile(filename=filename, file=BytesIO(payload), size=size)
def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
upload = UploadFile(filename="example.txt", file=BytesIO(b"abcdef"), size=None)
upload.file.seek(2)
size = utils.get_upload_size_bytes(upload)
assert size == 6
assert upload.file.tell() == 2
def test_get_upload_size_bytes_logs_warning_when_stream_size_unavailable(
caplog: pytest.LogCaptureFixture,
) -> None:
upload = UploadFile(filename="non_seekable.txt", file=_NonSeekableFile(), size=None)
caplog.set_level("WARNING")
size = utils.get_upload_size_bytes(upload)
assert size is None
assert "Could not determine upload size via stream seek" in caplog.text
assert "non_seekable.txt" in caplog.text
def test_is_upload_too_large_logs_warning_when_size_unknown(
monkeypatch: pytest.MonkeyPatch,
caplog: pytest.LogCaptureFixture,
) -> None:
upload = _make_upload("size_unknown.txt", size=1)
monkeypatch.setattr(utils, "get_upload_size_bytes", lambda _upload: None)
caplog.set_level("WARNING")
is_too_large = utils.is_upload_too_large(upload, max_bytes=100)
assert is_too_large is False
assert "Could not determine upload size; skipping size-limit check" in caplog.text
assert "size_unknown.txt" in caplog.text
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert len(result.rejected) == 0
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert len(result.rejected) == 0
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("edge.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert len(result.rejected) == 0
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([small, large], MagicMock())
assert [file.filename for file in result.acceptable] == ["small.png"]
assert len(result.rejected) == 1
assert result.rejected[0].filename == "large.png"
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"

View File

@@ -1,32 +0,0 @@
import pytest
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
class _FakeKvStore:
def load(self, _key: str) -> dict:
raise KvKeyNotFoundError()
class _FakeCache:
def __init__(self) -> None:
self._vals: dict[str, bytes] = {}
def get(self, key: str) -> bytes | None:
return self._vals.get(key)
def set(self, key: str, value: str, ex: int | None = None) -> None: # noqa: ARG002
self._vals[key] = value.encode("utf-8")
def test_load_settings_includes_user_file_max_upload_size_mb(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 77

View File

@@ -147,18 +147,15 @@ class TestSensitiveValueString:
)
assert sensitive1 != sensitive2
def test_equality_with_non_sensitive_returns_not_equal(self) -> None:
"""Test that comparing with non-SensitiveValue is always not-equal.
Returns NotImplemented so Python falls back to identity comparison.
This is required for compatibility with SQLAlchemy's attribute tracking.
"""
def test_equality_with_non_sensitive_raises(self) -> None:
"""Test that comparing with non-SensitiveValue raises error."""
sensitive = SensitiveValue(
encrypted_bytes=_encrypt_string("secret"),
decrypt_fn=_decrypt_string,
is_json=False,
)
assert not (sensitive == "secret")
with pytest.raises(SensitiveAccessError):
_ = sensitive == "secret"
class TestSensitiveValueJson:

View File

@@ -158,14 +158,14 @@ python ./scripts/dev_run_background_jobs.py
To run the backend API server, navigate back to `onyx/backend` and run:
```bash
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
$env:AUTH_TYPE='basic'
$env:AUTH_TYPE='disabled'
uvicorn onyx.main:app --reload --port 8080
"
```

View File

@@ -21,7 +21,7 @@ services:
env_file:
- .env_eval
environment:
- AUTH_TYPE=basic
- AUTH_TYPE=disabled
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- REDIS_HOST=cache
@@ -58,7 +58,7 @@ services:
env_file:
- .env_eval
environment:
- AUTH_TYPE=basic
- AUTH_TYPE=disabled
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- REDIS_HOST=cache

View File

@@ -20,12 +20,8 @@ IMAGE_TAG=latest
## Auth Settings
### https://docs.onyx.app/deployment/authentication
AUTH_TYPE=basic
AUTH_TYPE=disabled
# SESSION_EXPIRE_TIME_SECONDS=
### Recommended for basic auth - used for signing password reset and verification tokens
### If using install.sh, this will be auto-generated
### If setting manually, run: openssl rand -hex 32
USER_AUTH_SECRET=""
### Recommend to set this for security
# ENCRYPTION_KEY_SECRET=
### Optional

View File

@@ -654,20 +654,17 @@ else
sed -i.bak "s/^IMAGE_TAG=.*/IMAGE_TAG=$VERSION/" "$ENV_FILE"
print_success "IMAGE_TAG set to $VERSION"
# Configure basic authentication (default)
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
print_success "Basic authentication enabled in configuration"
# Check if openssl is available
if ! command -v openssl &> /dev/null; then
print_error "openssl is required to generate secure secrets but was not found."
exit 1
# Configure authentication settings based on selection
if [ "$AUTH_SCHEMA" = "disabled" ]; then
# Disable authentication in .env file
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=disabled/' "$ENV_FILE" 2>/dev/null || true
print_success "Authentication disabled in configuration"
else
# Enable basic authentication
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
print_success "Basic authentication enabled in configuration"
fi
# Generate a secure USER_AUTH_SECRET
USER_AUTH_SECRET=$(openssl rand -hex 32)
sed -i.bak "s/^USER_AUTH_SECRET=.*/USER_AUTH_SECRET=\"$USER_AUTH_SECRET\"/" "$ENV_FILE" 2>/dev/null || true
# Configure Craft based on flag or if using a craft-* image tag
# By default, env.template has Craft commented out (disabled)
if [ "$INCLUDE_CRAFT" = true ] || [[ "$VERSION" == craft-* ]]; then

View File

@@ -1266,5 +1266,3 @@ configMap:
SKIP_USERFILE_THRESHOLD: ""
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
# Maximum user upload file size in MB for chat/projects uploads
USER_FILE_MAX_UPLOAD_SIZE_MB: ""

View File

@@ -143,7 +143,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.6.3",
"onyx-devtools==0.7.0",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",
@@ -153,7 +153,7 @@ dev = [
"pytest-repeat==0.9.4",
"pytest-xdist==3.8.0",
"pytest==8.3.5",
"release-tag==0.4.3",
"release-tag==0.5.2",
"reorder-python-imports-black==3.14.0",
"ruff==0.12.0",
"types-beautifulsoup4==4.12.0.3",

35
uv.lock generated
View File

@@ -4443,7 +4443,7 @@ requires-dist = [
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.3" },
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.0" },
{ name = "openai", specifier = "==2.14.0" },
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
@@ -4485,7 +4485,7 @@ requires-dist = [
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
@@ -4548,20 +4548,19 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
[[package]]
name = "onyx-devtools"
version = "0.6.3"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fastapi" },
{ name = "openapi-generator-cli" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/e2/e7619722c3ccd18eb38100f776fb3dd6b4ae0fbbee09fca5af7c69a279b5/onyx_devtools-0.6.3-py3-none-any.whl", hash = "sha256:d3a5422945d9da12cafc185f64b39f6e727ee4cc92b37427deb7a38f9aad4966", size = 3945381, upload-time = "2026-03-05T20:39:25.896Z" },
{ url = "https://files.pythonhosted.org/packages/f2/09/513d2dabedc1e54ad4376830fc9b34a3d9c164bdbcdedfcdbb8b8154dc5a/onyx_devtools-0.6.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:efe300e9f3a2e7ae75f88a4f9e0a5c4c471478296cb1615b6a1f03d247582e13", size = 3978761, upload-time = "2026-03-05T20:39:28.822Z" },
{ url = "https://files.pythonhosted.org/packages/39/41/e757602a0de032d74ed01c7ee57f30e57728fb9cd4f922f50d2affda3889/onyx_devtools-0.6.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:594066eed3f917cfab5a8c7eac3d4a210df30259f2049f664787749709345e19", size = 3665378, upload-time = "2026-03-05T20:44:22.696Z" },
{ url = "https://files.pythonhosted.org/packages/33/1c/c93b65d0b32e202596a2647922a75c7011cb982f899ddfcfd171f792c58f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:384ef66030b55c0fd68b3898782b5b4b868ff3de119569dfc8544e2ce534b98a", size = 3540890, upload-time = "2026-03-05T20:39:28.886Z" },
{ url = "https://files.pythonhosted.org/packages/f4/33/760eb656013f7f0cdff24570480d3dc4e52bbd8e6147ea1e8cf6fad7554f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:82e218f3a49f64910c2c4c34d5dc12d1ea1520a27e0b0f6e4c0949ff9abaf0e1", size = 3945396, upload-time = "2026-03-05T20:39:34.323Z" },
{ url = "https://files.pythonhosted.org/packages/1a/eb/f54b3675c464df8a51194ff75afc97c2417659e3a209dc46948b47c28860/onyx_devtools-0.6.3-py3-none-win_amd64.whl", hash = "sha256:8af614ae7229290ef2417cb85270184a1e826ed9a3a34658da93851edb36df57", size = 4045936, upload-time = "2026-03-05T20:39:28.375Z" },
{ url = "https://files.pythonhosted.org/packages/04/b8/5bee38e748f3d4b8ec935766224db1bbc1214c91092e5822c080fccd9130/onyx_devtools-0.6.3-py3-none-win_arm64.whl", hash = "sha256:717589db4b42528d33ae96f8006ee6aad3555034dcfee724705b6576be6a6ec4", size = 3608268, upload-time = "2026-03-05T20:39:28.731Z" },
{ url = "https://files.pythonhosted.org/packages/22/9e/6957b11555da57d9e97092f4cd8ac09a86666264b0c9491838f4b27db5dc/onyx_devtools-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ad962a168d46ea11dcde9fa3b37e4f12ec520b4a4cb4d49d8732de110d46c4b6", size = 3998057, upload-time = "2026-03-12T03:09:11.585Z" },
{ url = "https://files.pythonhosted.org/packages/cd/90/c72f3d06ba677012d77c77de36195b6a32a15c755c79ba0282be74e3c366/onyx_devtools-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e46d252e2b048ff053b03519c3a875998780738d7c334eaa1c9a32ff445e3e1a", size = 3687753, upload-time = "2026-03-12T03:09:11.742Z" },
{ url = "https://files.pythonhosted.org/packages/10/42/4e9fe36eccf9f76d67ba8f4ff6539196a09cd60351fb63f5865e1544cbfa/onyx_devtools-0.7.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:f280bc9320e1cc310e7d753a371009bfaab02cc0e0cfd78559663b15655b5a50", size = 3560144, upload-time = "2026-03-12T03:12:24.02Z" },
{ url = "https://files.pythonhosted.org/packages/76/40/36dc12d99760b358c7f39b27361cb18fa9681ffe194107f982d0e1a74016/onyx_devtools-0.7.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:e31df751c7540ae7e70a7fe8e1153c79c31c2254af6aa4c72c0dd54fa381d2ab", size = 3964387, upload-time = "2026-03-12T03:09:11.356Z" },
{ url = "https://files.pythonhosted.org/packages/34/18/74744230c3820a5a7687335507ca5f1dbebab2c5325805041c1cd5703e6a/onyx_devtools-0.7.0-py3-none-win_amd64.whl", hash = "sha256:541bfd347c2d5b11e7f63ab5001d2594df91d215ad9d07b1562f5e715700f7e6", size = 4068030, upload-time = "2026-03-12T03:09:12.98Z" },
{ url = "https://files.pythonhosted.org/packages/8c/78/1320436607d3ffcb321ba7b064556c020ea15843a7e7d903fbb7529a71f5/onyx_devtools-0.7.0-py3-none-win_arm64.whl", hash = "sha256:83016330a9d39712431916cc25b2fb2cfcaa0112a55cc4f919d545da3a8974f9", size = 3626409, upload-time = "2026-03-12T03:09:10.222Z" },
]
[[package]]
@@ -6338,16 +6337,16 @@ wheels = [
[[package]]
name = "release-tag"
version = "0.4.3"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
]
[[package]]

3
web/.gitignore vendored
View File

@@ -47,6 +47,3 @@ next-env.d.ts
# generated clients ... in particular, the API to the Onyx backend itself!
/src/lib/generated
.jest-cache
# storybook
storybook-static

View File

@@ -1,115 +0,0 @@
import { Meta } from "@storybook/blocks";
<Meta title="Getting Started" />
# Onyx Storybook
A living catalog for browsing, testing, and documenting Onyx UI components in isolation.
---
## What is this?
This Storybook contains interactive examples of every reusable UI component in the Onyx frontend. Each component has a dedicated page with:
- **Live demos** you can interact with directly
- **Controls** to tweak props and see how the component responds
- **Auto-generated docs** showing the full props API
- **Dark mode toggle** in the toolbar to preview both themes
---
## Navigating Storybook
### Sidebar
The left sidebar organizes components by layer:
- **opal/core** — Low-level primitives (`Interactive`, `Hoverable`)
- **opal/components** — Design system atoms (`Button`, `OpenButton`, `Tag`)
- **Layouts** — Structural layouts (`Content`, `ContentAction`, `IllustrationContent`)
- **refresh-components** — App-level components (inputs, modals, tables, text, etc.)
Click any component to see its stories. Click **Docs** to see the auto-generated props table.
### Controls panel
At the bottom of each story, the **Controls** panel lets you change props in real time. Toggle booleans, pick from enums, type in strings — the preview updates instantly.
### Theme toggle
Use the paint roller icon in the top toolbar to switch between **light** and **dark** mode. All components use CSS variables that automatically adapt.
---
## Running locally
```bash
cd web
npm run storybook # dev server on :6006
npm run storybook:build # static build to storybook-static/
```
---
## Adding a new story
Stories are **co-located** next to their component:
```
lib/opal/src/components/buttons/Button/
├── components.tsx ← the component
├── Button.stories.tsx ← the story
├── styles.css
└── README.md
```
### Minimal template
```tsx
import type { Meta, StoryObj } from "@storybook/react";
import { MyComponent } from "./MyComponent";
const meta: Meta<typeof MyComponent> = {
title: "opal/components/MyComponent", // sidebar path
component: MyComponent,
tags: ["autodocs"], // auto-generate docs page
};
export default meta;
type Story = StoryObj<typeof MyComponent>;
export const Default: Story = {
args: {
title: "Hello",
},
};
export const WithCustomLayout: Story = {
render: () => (
<div className="flex gap-2">
<MyComponent title="One" />
<MyComponent title="Two" />
</div>
),
};
```
### Conventions
- **Title format:** `opal/core/Name`, `opal/components/Name`, `Layouts/Name`, or `refresh-components/Name`
- **Tags:** Add `tags: ["autodocs"]` to auto-generate a docs page from props
- **Decorators:** If your component needs `TooltipPrimitive.Provider` (anything with tooltips), add it as a decorator
- **Layout:** Use `parameters: { layout: "fullscreen" }` for modals/popovers that use portals
---
## Deployment
Production builds deploy to [onyx-storybook.vercel.app](https://onyx-storybook.vercel.app) automatically when PRs touching component files merge to `main`.
Monitored paths:
- `web/lib/opal/**`
- `web/src/refresh-components/**`
- `web/.storybook/**`

Some files were not shown because too many files have changed in this diff Show More