Compare commits

..

11 Commits

268 changed files with 6306 additions and 5375 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

View File

@@ -0,0 +1,151 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'

on:
  # schedule:
  #   - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
  workflow_dispatch: # Allows manual triggering

permissions:
  actions: read
  contents: read

jobs:
  scan-licenses:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-scan-licenses"]
    timeout-minutes: 45
    permissions:
      actions: read
      contents: read
      security-events: write
    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
        with:
          python-version: '3.11'
          cache: 'pip'
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
      - name: Get explicit and transitive dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
          pip freeze > requirements-all.txt
      - name: Check python
        id: license_check_report
        uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
        with:
          requirements: 'requirements-all.txt'
          fail: 'Copyleft'
          exclude: '(?i)^(pylint|aio[-_]*).*'
      - name: Print report
        if: always()
        env:
          # Pass the report through an env var rather than inline ${{ }} in the
          # shell line so report contents cannot be interpreted by the shell.
          REPORT: ${{ steps.license_check_report.outputs.report }}
        run: echo "$REPORT"
      - name: Install npm dependencies
        working-directory: ./web
        run: npm ci
      # be careful enabling the sarif and upload as it may spam the security tab
      # with a huge amount of items. Work out the issues before enabling upload.
      # - name: Run Trivy vulnerability scanner in repo mode
      #   if: always()
      #   uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
      #   with:
      #     scan-type: fs
      #     scan-ref: .
      #     scanners: license
      #     format: table
      #     severity: HIGH,CRITICAL
      #     # format: sarif
      #     # output: trivy-results.sarif
      #
      # - name: Upload Trivy scan results to GitHub Security tab
      #   uses: github/codeql-action/upload-sarif@v3
      #   with:
      #     sarif_file: trivy-results.sarif

  scan-trivy:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-scan-trivy"]
    timeout-minutes: 45
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}
      # Backend
      - name: Pull backend docker image
        run: docker pull onyxdotapp/onyx-backend:latest
      - name: Run Trivy vulnerability scanner on backend
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          # Hardcoded DB mirrors to work around trivy rate limiting (see header links)
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-backend:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
      # Web server
      - name: Pull web server docker image
        run: docker pull onyxdotapp/onyx-web-server:latest
      - name: Run Trivy vulnerability scanner on web server
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-web-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0
      # Model server
      - name: Pull model server docker image
        run: docker pull onyxdotapp/onyx-model-server:latest
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-model-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0

View File

@@ -1,79 +0,0 @@
name: Post-Merge Beta Cherry-Pick

on:
  push:
    branches:
      - main

permissions:
  contents: write
  pull-requests: write

jobs:
  cherry-pick-to-latest-release:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Resolve merged PR and checkbox state
        id: gate
        env:
          GH_TOKEN: ${{ github.token }}
        run: |
          # For the commit that triggered this workflow (HEAD on main), fetch all
          # associated PRs and keep only the PR that was actually merged into main
          # with this exact merge commit SHA.
          pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
          match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
          pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
          if [ "${match_count}" -gt 1 ]; then
            echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
          fi
          if [ -z "$pr_number" ]; then
            echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
            echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
            exit 0
          fi
          # Read the PR body and check whether the helper checkbox is checked.
          pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
          echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
          if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
            echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
            echo "Cherry-pick checkbox checked for PR #${pr_number}."
            exit 0
          fi
          echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
          echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
      - name: Checkout repository
        if: steps.gate.outputs.should_cherrypick == 'true'
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          # Credentials are kept so the cherry-pick tooling can push branches.
          persist-credentials: true
          ref: main
      - name: Install the latest version of uv
        if: steps.gate.outputs.should_cherrypick == 'true'
        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
      - name: Configure git identity
        if: steps.gate.outputs.should_cherrypick == 'true'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
      - name: Create cherry-pick PR to latest release
        if: steps.gate.outputs.should_cherrypick == 'true'
        env:
          GH_TOKEN: ${{ github.token }}
          GITHUB_TOKEN: ${{ github.token }}
        run: |
          uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify

View File

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration

concurrency:
  group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  pull_request:
    # 'edited' re-runs the check when the PR body (and its checkboxes) changes.
    types: [opened, edited, reopened, synchronize]

permissions:
  contents: read

jobs:
  beta-cherrypick-check:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Check PR body for beta cherry-pick consideration
        env:
          # Pass the PR body through an env var rather than inline ${{ }} in the
          # shell line so attacker-controlled body text cannot inject shell code.
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
            echo "Cherry-pick consideration box is checked. Check passed."
            exit 0
          fi
          echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
          exit 1

View File

@@ -41,7 +41,8 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"

View File

@@ -22,9 +22,6 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -470,3 +467,48 @@ jobs:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
# chromatic-tests:
# name: Chromatic Tests
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
# with:
# fetch-depth: 0
# - name: Setup node
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
# with:
# node-version: 22
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
# - name: Download Playwright test results
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -2,10 +2,7 @@ Copyright (c) 2023-present DanswerAI, Inc.
Portions of this software are licensed as follows:
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

View File

@@ -134,7 +134,6 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./keys /app/keys
# Escape hatch scripts
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging

View File

@@ -0,0 +1,27 @@
"""add_user_preferences
Revision ID: 175ea04c7087
Revises: d56ffa94ca32
Create Date: 2026-02-04 18:16:24.830873
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "175ea04c7087"
down_revision = "d56ffa94ca32"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("user_preferences", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "user_preferences")

View File

@@ -1,20 +1,20 @@
The Onyx Enterprise License (the "Enterprise License")
The DanswerAI Enterprise license (the Enterprise License)
Copyright (c) 2023-present DanswerAI, Inc.
With regard to the Onyx Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the Enterprise Terms), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise License for the
and otherwise have a valid Onyx Enterprise license for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise License for the correct
exploited with a valid Onyx Enterprise license for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain

View File

@@ -263,15 +263,9 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=source,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:

View File

@@ -50,7 +50,12 @@ def github_doc_sync(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:

View File

@@ -18,7 +18,12 @@ def github_group_sync(
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")

View File

@@ -50,7 +50,12 @@ def gmail_doc_sync(
already populated.
"""
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
gmail_connector.load_credentials(credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback

View File

@@ -295,7 +295,12 @@ def gdrive_doc_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

View File

@@ -391,7 +391,12 @@ def gdrive_group_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
admin_service = get_admin_service(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)

View File

@@ -24,7 +24,12 @@ def jira_doc_sync(
jira_connector = JiraConnector(
**cc_pair.connector.connector_specific_config,
)
jira_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -119,8 +119,13 @@ def jira_group_sync(
if not jira_base_url:
raise ValueError("No jira_base_url found in connector config")
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_client = build_jira_client(
credentials=cc_pair.credential.credential_json,
credentials=credential_json,
jira_base=jira_base_url,
scoped_token=scoped_token,
)

View File

@@ -30,7 +30,11 @@ def get_any_salesforce_client_for_doc_id(
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = first_cc_pair.credential.credential_json
credential_json = (
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
if first_cc_pair.credential.credential_json
else {}
)
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
@@ -158,7 +162,11 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],

View File

@@ -24,7 +24,12 @@ def sharepoint_doc_sync(
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
sharepoint_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -25,7 +25,12 @@ def sharepoint_group_sync(
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
connector.load_credentials(credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")

View File

@@ -151,9 +151,14 @@ def slack_doc_sync(
tenant_id = get_current_tenant_id()
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
cc_pair.credential.credential_json["slack_bot_token"],
credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -63,9 +63,14 @@ def slack_group_sync(
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
cc_pair.credential.credential_json["slack_bot_token"],
credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -25,7 +25,12 @@ def teams_doc_sync(
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
teams_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
teams_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -150,9 +151,12 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -109,9 +109,7 @@ async def _make_billing_request(
headers = _get_headers(license_data)
try:
async with httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT, follow_redirects=True
) as client:
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
else:

View File

@@ -270,7 +270,11 @@ def confluence_oauth_accessible_resources(
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
credential_dict = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
access_token = credential_dict["confluence_access_token"]
try:
@@ -337,7 +341,12 @@ def confluence_oauth_finalize(
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
existing_credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
new_credential_json: dict[str, Any] = dict(existing_credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url

View File

@@ -1,13 +1,9 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -44,14 +40,6 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -93,18 +81,6 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
@@ -113,11 +89,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license found in cache or DB.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.
settings.application_status = _BLOCKING_STATUS
# No license = community edition, disable EE features
settings.ee_features_enabled = False
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")

View File

@@ -177,7 +177,7 @@ async def forward_to_control_plane(
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
try:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
async with httpx.AsyncClient(timeout=30.0) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":

View File

@@ -12,14 +12,12 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -47,23 +45,6 @@ def list_user_groups(
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,

View File

@@ -76,18 +76,6 @@ class UserGroup(BaseModel):
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
)
class UserGroupCreate(BaseModel):
name: str
user_ids: list[UUID]

View File

@@ -11,6 +11,7 @@ from onyx.db.models import OAuthUserToken
from onyx.db.oauth_config import get_user_oauth_token
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
logger = setup_logger()
@@ -33,7 +34,10 @@ class OAuthTokenManager:
if not user_token:
return None
token_data = user_token.token_data
if not user_token.token_data:
return None
token_data = self._unwrap_token_data(user_token.token_data)
# Check if token is expired
if OAuthTokenManager.is_token_expired(token_data):
@@ -51,7 +55,10 @@ class OAuthTokenManager:
def refresh_token(self, user_token: OAuthUserToken) -> str:
"""Refresh access token using refresh token"""
token_data = user_token.token_data
if not user_token.token_data:
raise ValueError("No token data available for refresh")
token_data = self._unwrap_token_data(user_token.token_data)
response = requests.post(
self.oauth_config.token_url,
@@ -153,3 +160,11 @@ class OAuthTokenManager:
separator = "&" if "?" in oauth_config.authorization_url else "?"
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
) -> dict[str, Any]:
if isinstance(token_data, SensitiveValue):
return token_data.get_value(apply_mask=False)
return token_data

View File

@@ -60,7 +60,6 @@ from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
@@ -111,7 +110,6 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -274,22 +272,6 @@ def verify_email_domain(email: str) -> None:
)
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
"""Raise HTTPException(402) if adding users would exceed the seat limit.
No-op for multi-tenant or CE deployments.
"""
if MULTI_TENANT:
return
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=seats_needed)
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -419,12 +401,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
user_create.role = UserRole.ADMIN
# Check seat availability for new users (single-tenant only)
with get_session_with_current_tenant() as sync_db:
existing = get_user_by_email(user_create.email, sync_db)
if existing is None:
enforce_seat_limit(sync_db)
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request)
@@ -634,10 +610,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
# Check seat availability before creating (single-tenant only)
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
password = self.password_helper.generate()
user_dict = {
"email": account_email,

View File

@@ -12,7 +12,6 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -20,14 +19,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -57,17 +54,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -131,24 +117,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -163,21 +132,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -190,35 +145,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -226,8 +158,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -244,12 +175,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -27,8 +27,8 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -60,28 +60,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
if model_provider not in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}:
return False
return any(
(
msg.message_type == MessageType.ASSISTANT
and msg.tool_calls
and len(msg.tool_calls) > 0
)
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
for msg in simple_chat_history
)
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
@@ -393,7 +371,7 @@ def run_llm_loop(
custom_agent_prompt: str | None,
project_files: ExtractedProjectFiles,
persona: Persona | None,
memories: list[str] | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
@@ -477,12 +455,7 @@ def run_llm_loop(
elif out_of_cycles or ran_image_gen:
# Last cycle, no tools allowed, just answer!
tool_choice = ToolChoiceOptions.NONE
# Bedrock requires tool config in requests that include toolUse/toolResult history.
final_tools = (
tools
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
else []
)
final_tools = []
else:
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
@@ -511,7 +484,7 @@ def run_llm_loop(
system_prompt_str = build_system_prompt(
base_system_prompt=default_base_system_prompt,
datetime_aware=persona.datetime_aware if persona else True,
memories=memories,
user_memory_context=user_memory_context,
tools=tools,
should_cite_documents=should_cite_documents
or always_cite_documents,
@@ -665,7 +638,7 @@ def run_llm_loop(
tool_calls=tool_calls,
tools=final_tools,
message_history=truncated_message_history,
memories=memories,
user_memory_context=user_memory_context,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
next_citation_num=citation_processor.get_next_citation_number(),

View File

@@ -471,7 +471,7 @@ def handle_stream_message_objects(
# Filter chat_history to only messages after the cutoff
chat_history = [m for m in chat_history if m.id > cutoff_id]
memories = get_memories(user, db_session)
user_memory_context = get_memories(user, db_session)
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
@@ -480,7 +480,7 @@ def handle_stream_message_objects(
persona_system_prompt=custom_agent_prompt or "",
token_counter=token_counter,
files=new_msg_req.file_descriptors,
memories=memories,
user_memory_context=user_memory_context,
)
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
@@ -667,7 +667,7 @@ def handle_stream_message_objects(
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
memories=memories,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,

View File

@@ -4,6 +4,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.memory import UserMemoryContext
from onyx.db.persona import get_default_behavior_persona
from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
@@ -12,7 +13,6 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.chat_prompts import USER_INFO_HEADER
from onyx.prompts.prompt_utils import get_company_context
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
@@ -25,6 +25,7 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
from onyx.prompts.user_info import USER_INFORMATION_HEADER
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
@@ -52,7 +53,7 @@ def calculate_reserved_tokens(
persona_system_prompt: str,
token_counter: Callable[[str], int],
files: list[FileDescriptor] | None = None,
memories: list[str] | None = None,
user_memory_context: UserMemoryContext | None = None,
) -> int:
"""
Calculate reserved token count for system prompt and user files.
@@ -66,7 +67,7 @@ def calculate_reserved_tokens(
persona_system_prompt: Custom agent system prompt (can be empty string)
token_counter: Function that counts tokens in text
files: List of file descriptors from the chat message (optional)
memories: List of memory strings (optional)
user_memory_context: User memory context (optional)
Returns:
Total reserved token count
@@ -77,7 +78,7 @@ def calculate_reserved_tokens(
fake_system_prompt = build_system_prompt(
base_system_prompt=base_system_prompt,
datetime_aware=True,
memories=memories,
user_memory_context=user_memory_context,
tools=None,
should_cite_documents=True,
include_all_guidance=True,
@@ -133,7 +134,7 @@ def build_reminder_message(
def build_system_prompt(
base_system_prompt: str,
datetime_aware: bool = False,
memories: list[str] | None = None,
user_memory_context: UserMemoryContext | None = None,
tools: Sequence[Tool] | None = None,
should_cite_documents: bool = False,
include_all_guidance: bool = False,
@@ -157,14 +158,15 @@ def build_system_prompt(
)
company_context = get_company_context()
if company_context or memories:
system_prompt += USER_INFO_HEADER
formatted_user_context = (
user_memory_context.as_formatted_prompt() if user_memory_context else ""
)
if company_context or formatted_user_context:
system_prompt += USER_INFORMATION_HEADER
if company_context:
system_prompt += company_context
if memories:
system_prompt += "\n".join(
"- " + memory.strip() for memory in memories if memory.strip()
)
if formatted_user_context:
system_prompt += formatted_user_context
# Append citation guidance after company context if placeholder was not present
# This maintains backward compatibility and ensures citations are always enforced when needed

View File

@@ -75,7 +75,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
@@ -900,9 +900,6 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
# Security and authentication
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""

View File

@@ -158,17 +158,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -446,9 +435,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -65,7 +65,9 @@ class OnyxDBCredentialsProvider(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
if credential.credential_json is None:
return {}
return credential.credential_json.get_value(apply_mask=False)
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
@@ -81,7 +83,7 @@ class OnyxDBCredentialsProvider(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
except Exception:
db_session.rollback()

View File

@@ -118,7 +118,12 @@ def instantiate_connector(
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
new_credentials = connector.load_credentials(credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)

View File

@@ -32,8 +32,6 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -47,13 +45,9 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=10,
total=5,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -67,24 +61,8 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -128,8 +106,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -164,8 +142,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -216,8 +194,7 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
# in ~2 minutes
current_attempt = 0
while True:
current_attempt += 1
@@ -236,14 +213,11 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -6,7 +6,6 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -168,14 +167,6 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
return {
key: [str(item) for item in val] if isinstance(val, list) else str(val)
for key, val in v.items()
}
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None

View File

@@ -270,6 +270,8 @@ def create_credential(
)
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -297,14 +299,21 @@ def alter_credential(
credential.name = name
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
# Get existing credential_json and merge with new values
existing_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
credential.credential_json = { # type: ignore[assignment]
**existing_json,
**credential_json,
}
credential.user_id = user.id
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -318,10 +327,12 @@ def update_credential(
if credential is None:
return None
credential.credential_json = credential_data.credential_json
credential.user_id = user.id
credential.credential_json = credential_data.credential_json # type: ignore[assignment]
credential.user_id = user.id if user is not None else None
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -335,8 +346,10 @@ def update_credential_json(
if credential is None:
return None
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -346,7 +359,7 @@ def backend_update_credential_json(
db_session: Session,
) -> None:
"""This should not be used in any flows involving the frontend or users"""
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
@@ -441,7 +454,12 @@ def create_initial_public_credential(db_session: Session) -> None:
)
if first_credential is not None:
if first_credential.credential_json != {} or first_credential.user is not None:
credential_json_value = (
first_credential.credential_json.get_value(apply_mask=False)
if first_credential.credential_json
else {}
)
if credential_json_value != {} or first_credential.user is not None:
raise ValueError(error_msg)
return
@@ -477,8 +495,13 @@ def delete_service_account_credentials(
) -> None:
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
for credential in credentials:
credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
and credential.source == source
):
db_session.delete(credential)

View File

@@ -111,7 +111,7 @@ def update_federated_connector_oauth_token(
if existing_token:
# Update existing token
existing_token.token = token
existing_token.token = token # type: ignore[assignment]
existing_token.expires_at = expires_at
db_session.commit()
return existing_token
@@ -267,7 +267,13 @@ def update_federated_connector(
# Use provided credentials if updating them, otherwise use existing credentials
# This is needed to instantiate the connector for config validation when only config is being updated
creds_to_use = (
credentials if credentials is not None else federated_connector.credentials
credentials
if credentials is not None
else (
federated_connector.credentials.get_value(apply_mask=False)
if federated_connector.credentials
else {}
)
)
if credentials is not None:
@@ -278,7 +284,7 @@ def update_federated_connector(
raise ValueError(
f"Invalid credentials for federated connector source: {federated_connector.source}"
)
federated_connector.credentials = credentials
federated_connector.credentials = credentials # type: ignore[assignment]
if config is not None:
# Validate config using connector-specific validation

View File

@@ -109,38 +109,45 @@ def can_user_access_llm_provider(
is_admin: If True, bypass user group restrictions but still respect persona restrictions
Access logic:
- is_public controls USER access (group bypass): when True, all users can access
regardless of group membership. When False, user must be in a whitelisted group
(or be admin).
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
This allows admins to make a provider available to all users while still
restricting which personas (assistants) can use it.
Decision matrix:
1. is_public=True, no personas set → everyone has access
2. is_public=True, personas set → all users, but only whitelisted personas
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
4. is_public=False, only groups set → must be in group (admins bypass)
5. is_public=False, only personas set → must use whitelisted persona
6. is_public=False, neither set → admin-only (locked)
1. If is_public=True → everyone has access (public override)
2. If is_public=False:
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
- Only personas set → must use one of the personas (OR across personas, applies to admins)
- Neither set → NOBODY has access unless admin (locked, admin-only)
"""
provider_group_ids = {g.id for g in (provider.groups or [])}
provider_persona_ids = {p.id for p in (provider.personas or [])}
has_groups = bool(provider_group_ids)
has_personas = bool(provider_persona_ids)
# Persona restrictions are always enforced when set, regardless of is_public
if has_personas and not (persona and persona.id in provider_persona_ids):
return False
# Public override - everyone has access
if provider.is_public:
return True
# Extract IDs once to avoid multiple iterations
provider_group_ids = (
{group.id for group in provider.groups} if provider.groups else set()
)
provider_persona_ids = (
{p.id for p in provider.personas} if provider.personas else set()
)
has_groups = bool(provider_group_ids)
has_personas = bool(provider_persona_ids)
# Both groups AND personas set → AND logic (must satisfy both)
if has_groups and has_personas:
# Admins bypass group check but still must satisfy persona restrictions
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
persona_allowed = persona.id in provider_persona_ids if persona else False
return user_in_group and persona_allowed
# Only groups set → user must be in one of the groups (admins bypass)
if has_groups:
return is_admin or bool(user_group_ids & provider_group_ids)
# No groups: either persona-whitelisted (already passed) or admin-only if locked
return has_personas or is_admin
# Only personas set → persona must be in allowed list (applies to admins too)
if has_personas:
return persona.id in provider_persona_ids if persona else False
# Neither groups nor personas set, and not public → admins can access
return is_admin
def validate_persona_ids_exist(
@@ -225,7 +232,8 @@ def upsert_llm_provider(
custom_config = custom_config or None
existing_llm_provider.provider = llm_provider_upsert_request.provider
existing_llm_provider.api_key = llm_provider_upsert_request.api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
@@ -421,7 +429,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_type_filter: list[LLMModelFlowType],
flow_types: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -429,27 +437,30 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_type_filter: List of flow types to filter by, empty list for no filter
flow_types: List of flow types to filter by
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
@@ -712,15 +723,13 @@ def sync_auto_mode_models(
changes += 1
else:
# Add new model - all models from GitHub config are visible
insert_new_model_configuration__no_commit(
db_session=db_session,
new_model = ModelConfiguration(
llm_provider_id=provider.id,
model_name=model_config.name,
supported_flows=[LLMModelFlowType.CHAT],
is_visible=True,
max_input_tokens=None,
name=model_config.name,
display_name=model_config.display_name,
is_visible=True,
)
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config

View File

@@ -19,6 +19,7 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.server.features.mcp.models import MCPConnectionData
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
logger = setup_logger()
@@ -204,6 +205,21 @@ def remove_user_from_mcp_server(
# MCPConnectionConfig operations
def extract_connection_data(
config: MCPConnectionConfig | None, apply_mask: bool = False
) -> MCPConnectionData:
"""Extract MCPConnectionData from a connection config, with proper typing.
This helper encapsulates the cast from the JSON column's dict[str, Any]
to the typed MCPConnectionData structure.
"""
if config is None or config.config is None:
return MCPConnectionData(headers={})
if isinstance(config.config, SensitiveValue):
return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask))
return cast(MCPConnectionData, config.config)
def get_connection_config_by_id(
config_id: int, db_session: Session
) -> MCPConnectionConfig:
@@ -269,7 +285,7 @@ def update_connection_config(
config = get_connection_config_by_id(config_id, db_session)
if config_data is not None:
config.config = config_data
config.config = config_data # type: ignore[assignment]
# Force SQLAlchemy to detect the change by marking the field as modified
flag_modified(config, "config")
@@ -287,7 +303,7 @@ def upsert_user_connection_config(
existing_config = get_user_connection_config(server_id, user_email, db_session)
if existing_config:
existing_config.config = config_data
existing_config.config = config_data # type: ignore[assignment]
db_session.flush() # Don't commit yet, let caller decide when to commit
return existing_config
else:

View File

@@ -1,22 +1,111 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
def get_memories(user: User, db_session: Session) -> list[str]:
class UserInfo(BaseModel):
name: str | None = None
role: str | None = None
email: str | None = None
def to_dict(self) -> dict:
return {
"name": self.name,
"role": self.role,
"email": self.email,
}
class UserMemoryContext(BaseModel):
model_config = ConfigDict(frozen=True)
user_info: UserInfo
user_preferences: str | None = None
memories: tuple[str, ...] = ()
def as_formatted_list(self) -> list[str]:
"""Returns combined list of user info, preferences, and memories."""
result = []
if self.user_info.name:
result.append(f"User's name: {self.user_info.name}")
if self.user_info.role:
result.append(f"User's role: {self.user_info.role}")
if self.user_info.email:
result.append(f"User's email: {self.user_info.email}")
if self.user_preferences:
result.append(f"User preferences: {self.user_preferences}")
result.extend(self.memories)
return result
def as_formatted_prompt(self) -> str:
"""Returns structured prompt sections for the system prompt."""
has_basic_info = (
self.user_info.name or self.user_info.email or self.user_info.role
)
if not has_basic_info and not self.user_preferences and not self.memories:
return ""
sections: list[str] = []
if has_basic_info:
role_line = (
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
if self.user_info.role
else ""
)
if role_line:
role_line = "\n" + role_line
sections.append(
BASIC_INFORMATION_PROMPT.format(
user_name=self.user_info.name or "",
user_email=self.user_info.email or "",
user_role=role_line,
)
)
if self.user_preferences:
sections.append(
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
)
if self.memories:
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
sections.append(
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
)
return "".join(sections)
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
if not user.use_memories:
return []
return UserMemoryContext(user_info=UserInfo())
user_info = [
f"User's name: {user.personal_name}" if user.personal_name else "",
f"User's role: {user.personal_role}" if user.personal_role else "",
f"User's email: {user.email}" if user.email else "",
]
user_info = UserInfo(
name=user.personal_name,
role=user.personal_role,
email=user.email,
)
user_preferences = None
if user.user_preferences:
user_preferences = user.user_preferences
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user.id)
select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc())
).all()
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
return user_info + memories
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
return UserMemoryContext(
user_info=user_info,
user_preferences=user_preferences,
memories=memories,
)

View File

@@ -95,10 +95,10 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.kg.models import KGStage
from onyx.server.features.mcp.models import MCPConnectionData
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.sensitive import SensitiveValue
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from onyx.context.search.enums import RecencyBiasSetting
@@ -122,18 +122,35 @@ class EncryptedString(TypeDecorator):
cache_ok = True
def process_bind_param(
self, value: str | None, dialect: Dialect # noqa: ARG002
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw strings and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
return encrypt_string_to_bytes(value)
return value
def process_result_value(
self, value: bytes | None, dialect: Dialect # noqa: ARG002
) -> str | None:
) -> SensitiveValue[str] | None:
if value is not None:
return decrypt_bytes_to_string(value)
return value
return SensitiveValue(
encrypted_bytes=value,
decrypt_fn=decrypt_bytes_to_string,
is_json=False,
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class EncryptedJson(TypeDecorator):
@@ -142,20 +159,38 @@ class EncryptedJson(TypeDecorator):
cache_ok = True
def process_bind_param(
self, value: dict | None, dialect: Dialect # noqa: ARG002
self,
value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None,
dialect: Dialect, # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw dicts and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
json_str = json.dumps(value)
return encrypt_string_to_bytes(json_str)
return value
def process_result_value(
self, value: bytes | None, dialect: Dialect # noqa: ARG002
) -> dict | None:
) -> SensitiveValue[dict[str, Any]] | None:
if value is not None:
json_str = decrypt_bytes_to_string(value)
return json.loads(json_str)
return value
return SensitiveValue(
encrypted_bytes=value,
decrypt_fn=decrypt_bytes_to_string,
is_json=True,
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class NullFilteredString(TypeDecorator):
@@ -216,6 +251,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
@@ -1755,7 +1791,9 @@ class Credential(Base):
)
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson()
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
@@ -1793,7 +1831,9 @@ class FederatedConnector(Base):
source: Mapped[FederatedConnectorSource] = mapped_column(
Enum(FederatedConnectorSource, native_enum=False)
)
credentials: Mapped[dict[str, str]] = mapped_column(EncryptedJson(), nullable=False)
credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False
)
config: Mapped[dict[str, Any]] = mapped_column(
postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
)
@@ -1820,7 +1860,9 @@ class FederatedConnectorOAuthToken(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
expires_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
@@ -1964,7 +2006,9 @@ class SearchSettings(Base):
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
if self.cloud_provider is None or self.cloud_provider.api_key is None:
return None
return self.cloud_provider.api_key.get_value(apply_mask=False)
@property
def large_chunks_enabled(self) -> bool:
@@ -2726,7 +2770,9 @@ class LLMProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider: Mapped[str] = mapped_column(String)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
# custom configs that should be passed to the LLM provider at inference time
@@ -2879,7 +2925,7 @@ class CloudEmbeddingProvider(Base):
Enum(EmbeddingProvider), primary_key=True
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -2898,7 +2944,9 @@ class InternetSearchProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
@@ -2920,7 +2968,9 @@ class InternetContentProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
config: Mapped[WebContentProviderConfig | None] = mapped_column(
PydanticType(WebContentProviderConfig), nullable=True
)
@@ -3064,8 +3114,12 @@ class OAuthConfig(Base):
token_url: Mapped[str] = mapped_column(Text, nullable=False)
# Client credentials (encrypted)
client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
# Optional configurations
scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
@@ -3112,7 +3166,9 @@ class OAuthUserToken(Base):
# "expires_at": 1234567890, # Unix timestamp, optional
# "scope": "repo user" # Optional
# }
token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)
token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False
)
# Metadata
created_at: Mapped[datetime.datetime] = mapped_column(
@@ -3445,9 +3501,15 @@ class SlackBot(Base):
name: Mapped[str] = mapped_column(String)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), unique=True
)
app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), unique=True
)
user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
@@ -3468,7 +3530,9 @@ class DiscordBotConfig(Base):
id: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'SINGLETON'")
)
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
@@ -3624,7 +3688,9 @@ class KVStore(Base):
key: Mapped[str] = mapped_column(String, primary_key=True)
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
encrypted_value: Mapped[JSON_ro] = mapped_column(EncryptedJson(), nullable=True)
encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=True
)
class FileRecord(Base):
@@ -4344,7 +4410,7 @@ class MCPConnectionConfig(Base):
# "registration_access_token": "<token>", # For managing registration
# "registration_client_uri": "<uri>", # For managing registration
# }
config: Mapped[MCPConnectionData] = mapped_column(
config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False, default=dict
)

View File

@@ -87,13 +87,13 @@ def update_oauth_config(
if token_url is not None:
oauth_config.token_url = token_url
if clear_client_id:
oauth_config.client_id = ""
oauth_config.client_id = "" # type: ignore[assignment]
elif client_id is not None:
oauth_config.client_id = client_id
oauth_config.client_id = client_id # type: ignore[assignment]
if clear_client_secret:
oauth_config.client_secret = ""
oauth_config.client_secret = "" # type: ignore[assignment]
elif client_secret is not None:
oauth_config.client_secret = client_secret
oauth_config.client_secret = client_secret # type: ignore[assignment]
if scopes is not None:
oauth_config.scopes = scopes
if additional_params is not None:
@@ -154,7 +154,7 @@ def upsert_user_oauth_token(
if existing_token:
# Update existing token
existing_token.token_data = token_data
existing_token.token_data = token_data # type: ignore[assignment]
db_session.commit()
return existing_token
else:

View File

@@ -43,9 +43,9 @@ def update_slack_bot(
# update the app
slack_bot.name = name
slack_bot.enabled = enabled
slack_bot.bot_token = bot_token
slack_bot.app_token = app_token
slack_bot.user_token = user_token
slack_bot.bot_token = bot_token # type: ignore[assignment]
slack_bot.app_token = app_token # type: ignore[assignment]
slack_bot.user_token = user_token # type: ignore[assignment]
db_session.commit()

View File

@@ -160,6 +160,7 @@ def update_user_personalization(
personal_role: str | None,
use_memories: bool,
memories: list[str],
user_preferences: str | None,
db_session: Session,
) -> None:
db_session.execute(
@@ -169,6 +170,7 @@ def update_user_personalization(
personal_name=personal_name,
personal_role=personal_role,
use_memories=use_memories,
user_preferences=user_preferences,
)
)

View File

@@ -73,7 +73,8 @@ def _apply_search_provider_updates(
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
provider.api_key = api_key # type: ignore[assignment]
def upsert_web_search_provider(
@@ -228,7 +229,8 @@ def _apply_content_provider_updates(
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
provider.api_key = api_key # type: ignore[assignment]
def upsert_web_content_provider(

View File

@@ -119,7 +119,16 @@ def get_federated_retrieval_functions(
federated_retrieval_infos_slack = []
# Use user_token if available, otherwise fall back to bot_token
access_token = tenant_slack_bot.user_token or tenant_slack_bot.bot_token
# Unwrap SensitiveValue for backend API calls
access_token = (
tenant_slack_bot.user_token.get_value(apply_mask=False)
if tenant_slack_bot.user_token
else (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else ""
)
)
if not tenant_slack_bot.user_token:
logger.warning(
f"Using bot_token for Slack search (limited functionality): {tenant_slack_bot.name}"
@@ -138,7 +147,12 @@ def get_federated_retrieval_functions(
)
# Capture variables by value to avoid lambda closure issues
bot_token = tenant_slack_bot.bot_token
# Unwrap SensitiveValue for backend API calls
bot_token = (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else ""
)
# Use connector config for channel filtering (guaranteed to exist at this point)
connector_entities = slack_federated_connector_config
@@ -252,11 +266,11 @@ def get_federated_retrieval_functions(
connector = get_federated_connector(
oauth_token.federated_connector.source,
oauth_token.federated_connector.credentials,
oauth_token.federated_connector.credentials.get_value(apply_mask=False),
)
# Capture variables by value to avoid lambda closure issues
access_token = oauth_token.token
access_token = oauth_token.token.get_value(apply_mask=False)
def create_retrieval_function(
conn: FederatedConnector,

View File

@@ -43,7 +43,7 @@ class PgRedisKVStore(KeyValueStore):
obj = db_session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
obj.encrypted_value = encrypted_val # type: ignore[assignment]
else:
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
db_session.query(KVStore).filter_by(key=key).delete() # just in case
@@ -73,7 +73,8 @@ class PgRedisKVStore(KeyValueStore):
if obj.value is not None:
value = obj.value
elif obj.encrypted_value is not None:
value = obj.encrypted_value
# Unwrap SensitiveValue - this is internal backend use
value = obj.encrypted_value.get_value(apply_mask=False)
else:
value = None

View File

@@ -1,4 +1,8 @@
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from contextlib import nullcontext
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -44,11 +48,13 @@ from onyx.llm.well_known_providers.constants import (
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT,
)
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
from onyx.server.utils import mask_string
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
_env_lock = threading.Lock()
if TYPE_CHECKING:
from litellm import CustomStreamWrapper
from litellm import HTTPHandler
@@ -378,23 +384,29 @@ class LitellmLLM(LLM):
if "api_key" not in passthrough_kwargs:
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
env_ctx = (
temporary_env_and_lock(self._custom_config)
if self._custom_config
else nullcontext()
)
with env_ctx:
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
)
return response
except Exception as e:
# for break pointing
@@ -475,22 +487,53 @@ class LitellmLLM(LLM):
client = HTTPHandler(timeout=timeout_override or self._timeout)
try:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
if self._custom_config:
# When custom_config is set, env vars are temporarily injected
# under a global lock. Using stream=True here means the lock is
# only held during connection setup (not the full inference).
# The chunks are then collected outside the lock and reassembled
# into a single ModelResponse via stream_chunk_builder.
from litellm import stream_chunk_builder
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
stream_response = cast(
LiteLLMCustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
chunks = list(stream_response)
response = cast(
LiteLLMModelResponse,
stream_chunk_builder(chunks),
)
else:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
model_response = from_litellm_model_response(response)
@@ -581,3 +624,29 @@ class LitellmLLM(LLM):
finally:
if client is not None:
client.close()
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
"""
Temporarily sets the environment variables to the given values.
Code path is locked while the environment variables are set.
Then cleans up the environment and frees the lock.
"""
with _env_lock:
logger.debug("Acquired lock in temporary_env_and_lock")
# Store original values (None if key didn't exist)
original_values: dict[str, str | None] = {
key: os.environ.get(key) for key in env_variables
}
try:
os.environ.update(env_variables)
yield
finally:
for key, original_value in original_values.items():
if original_value is None:
os.environ.pop(key, None) # Remove if it didn't exist before
else:
os.environ[key] = original_value # Restore original value
logger.debug("Released lock in temporary_env_and_lock")

View File

@@ -4,6 +4,7 @@ from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -36,4 +37,8 @@ def get_bot_token() -> str | None:
except Exception as e:
logger.error(f"Failed to get bot token from database: {e}")
return None
return config.bot_token if config else None
if config and config.bot_token:
if isinstance(config.bot_token, SensitiveValue):
return config.bot_token.get_value(apply_mask=False)
return config.bot_token
return None

View File

@@ -592,8 +592,11 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -605,6 +608,7 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,9 +1,5 @@
import re
from enum import Enum
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
SHOW_EVERYONE_ACTION_ID = "show-everyone"

View File

@@ -1,163 +1,29 @@
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
_HTML_TAG_PATTERN = re.compile(
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
"""Strip HTML tags from a text fragment.
Block-level closing tags and <br> are converted to newlines.
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
"""
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
text = _HTML_TAG_PATTERN.sub("", text)
return text
def _transform_outside_code_blocks(
message: str, transform: Callable[[str], str]
) -> str:
"""Apply *transform* only to text outside fenced code blocks."""
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
result: list[str] = []
for i, part in enumerate(parts):
result.append(transform(part))
if i < len(code_blocks):
result.append(code_blocks[i])
return "".join(result)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
Markdown link syntax [text](url) breaks when the URL contains unescaped
parentheses, spaces, or other special characters. Wrapping the URL in angle
brackets — [text](<url>) — tells the parser to treat everything inside as
a literal URL. This applies to all links, not just citations.
"""
if "](" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def _convert_slack_links_to_markdown(message: str) -> str:
"""Convert Slack-style <url|text> links to standard markdown [text](url).
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
recognise it, so the angle brackets would be escaped by text() and Slack
would render the link as literal text instead of a clickable link.
"""
return _transform_outside_code_blocks(
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
result = md(normalized_message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result.rstrip("\n")
return result
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
super().__init__()
self._table_headers: list[str] = []
self._current_row_cells: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
return text
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n\n"
return f"*{text}*\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
@@ -176,7 +42,7 @@ class SlackRenderer(HTMLRenderer):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines) + "\n"
return "\n".join(lines)
def list_item(self, text: str) -> str:
return f"li: {text}\n"
@@ -198,73 +64,7 @@ class SlackRenderer(HTMLRenderer):
return f"`{text}`"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
lines = text.strip().split("\n")
quoted = "\n".join(f">{line}" for line in lines)
return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
# Only escape the three entities Slack recognizes: & < >
# HTMLRenderer.text() also escapes " to &quot; which Slack renders
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
) -> str:
if head:
self._table_headers.append(text.strip())
else:
self._current_row_cells.append(text.strip())
return ""
def table_head(self, text: str) -> str: # noqa: ARG002
self._current_row_cells = []
return ""
def table_row(self, text: str) -> str: # noqa: ARG002
cells = self._current_row_cells
self._current_row_cells = []
# First column becomes the bold title, remaining columns are bulleted fields
lines: list[str] = []
if cells:
title = cells[0]
if title:
# Avoid double-wrapping if cell already contains bold markup
if title.startswith("*") and title.endswith("*") and len(title) > 1:
lines.append(title)
else:
lines.append(f"*{title}*")
for i, cell in enumerate(cells[1:], start=1):
if i < len(self._table_headers):
lines.append(f"{self._table_headers[i]}: {cell}")
else:
lines.append(f"{cell}")
return "\n".join(lines) + "\n\n"
def table_body(self, text: str) -> str:
return text
def table(self, text: str) -> str:
self._table_headers = []
self._current_row_cells = []
return text + "\n"
return f"```\n{code}\n```\n"
def paragraph(self, text: str) -> str:
return f"{text}\n\n"
return f"{text}\n"

View File

@@ -18,18 +18,15 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import Tag
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.onyxbot.slack.utils import get_channel_from_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
@@ -44,51 +41,6 @@ srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def resolve_channel_references(
message: str,
client: WebClient,
logger: OnyxLoggingAdapter,
) -> tuple[str, list[Tag]]:
"""Parse Slack channel references from a message, resolve IDs to names,
replace the raw markup with readable #channel-name, and return channel tags
for search filtering."""
tags: list[Tag] = []
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
seen_channel_ids: set[str] = set()
for channel_id, channel_name_from_markup in channel_matches:
if channel_id in seen_channel_ids:
continue
seen_channel_ids.add(channel_id)
channel_name = channel_name_from_markup or None
if not channel_name:
try:
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
channel_name = channel_info.get("name") or None
except Exception:
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
if not channel_name:
continue
# Replace raw Slack markup with readable channel name
if channel_name_from_markup:
message = message.replace(
f"<#{channel_id}|{channel_name_from_markup}>",
f"#{channel_name}",
)
else:
message = message.replace(
f"<#{channel_id}>",
f"#{channel_name}",
)
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
return message, tags
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
@@ -205,20 +157,6 @@ def handle_regular_answer(
user_message = messages[-1]
history_messages = messages[:-1]
# Resolve any <#CHANNEL_ID> references in the user message to readable
# channel names and extract channel tags for search filtering
resolved_message, channel_tags = resolve_channel_references(
message=user_message.message,
client=client,
logger=logger,
)
user_message = ThreadMessage(
message=resolved_message,
sender=user_message.sender,
role=user_message.role,
)
channel_name, _ = get_channel_name_from_id(
client=client,
channel_id=channel,
@@ -269,7 +207,6 @@ def handle_regular_answer(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
tags=channel_tags if channel_tags else None,
)
new_message_request = SendMessageRequest(
@@ -294,16 +231,6 @@ def handle_regular_answer(
slack_context_str=slack_context_str,
)
# If a channel filter was applied but no results were found, override
# the LLM response to avoid hallucinated answers about unindexed channels
if channel_tags and not answer.citation_info and not answer.top_documents:
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
answer.answer = (
f"No indexed data found for {channel_names}. "
"This channel may not be indexed, or there may be no messages "
"matching your query within it."
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
@@ -358,7 +285,6 @@ def handle_regular_answer(
only_respond_if_citations
and not answer.citation_info
and not message_info.bypass_filters
and not channel_tags
):
logger.error(
f"Unable to find citations to answer: '{answer.answer}' - not answering!"

View File

@@ -216,14 +216,10 @@ class SlackbotHandler:
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are missing or empty, close the socket client and remove them.
if not slack_bot_tokens:
if not bot.bot_token or not bot.app_token:
logger.debug(
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
@@ -233,6 +229,11 @@ class SlackbotHandler:
del self.slack_bot_tokens[tenant_bot_pair]
return
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token.get_value(apply_mask=False),
app_token=bot.app_token.get_value(apply_mask=False),
)
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
tokens_changed = (
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]

View File

@@ -25,9 +25,6 @@ You can use Markdown tables to format your responses for data, lists, and other
""".lstrip()
# Section for information about the user if provided such as their name, role, memories, etc.
USER_INFO_HEADER = "\n\n# User Information\n"
COMPANY_NAME_BLOCK = """
The user is at an organization called `{company_name}`.
"""

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -92,7 +92,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
@@ -404,12 +403,13 @@ def check_drive_tokens(
db_session: Session = Depends(get_session),
) -> AuthStatus:
db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session)
if (
not db_credentials
or DB_CREDENTIALS_DICT_TOKEN_KEY not in db_credentials.credential_json
):
if not db_credentials or not db_credentials.credential_json:
return AuthStatus(authenticated=False)
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
credential_json = db_credentials.credential_json.get_value(apply_mask=False)
if DB_CREDENTIALS_DICT_TOKEN_KEY not in credential_json:
return AuthStatus(authenticated=False)
token_json_str = str(credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
google_drive_creds = get_google_oauth_creds(
token_json_str=token_json_str,
source=DocumentSource.GOOGLE_DRIVE,
@@ -557,43 +557,6 @@ def _normalize_file_names_for_backwards_compatibility(
return file_names + file_locations[len(file_names) :]
def _fetch_and_check_file_connector_cc_pair_permissions(
connector_id: int,
user: User,
db_session: Session,
require_editable: bool,
) -> ConnectorCredentialPair:
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
)
has_requested_access = verify_user_has_access_to_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
user=user,
get_editable=require_editable,
)
if has_requested_access:
return cc_pair
# Special case: global curators should be able to manage files
# for public file connectors even when they are not the creator.
if (
require_editable
and user.role == UserRole.GLOBAL_CURATOR
and cc_pair.access_type == AccessType.PUBLIC
):
return cc_pair
raise HTTPException(
status_code=403,
detail="Access denied. User cannot manage files for this connector.",
)
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
@@ -605,7 +568,7 @@ def upload_files_api(
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
def list_connector_files(
connector_id: int,
user: User = Depends(current_curator_or_admin_user),
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ConnectorFilesResponse:
"""List all files in a file connector."""
@@ -618,13 +581,6 @@ def list_connector_files(
status_code=400, detail="This endpoint only works with file connectors"
)
_ = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=False,
)
file_locations = connector.connector_specific_config.get("file_locations", [])
file_names = connector.connector_specific_config.get("file_names", [])
@@ -674,7 +630,7 @@ def update_connector_files(
connector_id: int,
files: list[UploadFile] | None = File(None),
file_ids_to_remove: str = Form("[]"),
user: User = Depends(current_curator_or_admin_user),
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
"""
@@ -692,13 +648,12 @@ def update_connector_files(
)
# Get the connector-credential pair for indexing/pruning triggers
# and validate user permissions for file management.
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=True,
)
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
)
# Parse file IDs to remove
try:

View File

@@ -346,10 +346,17 @@ def update_credential_from_model(
detail=f"Credential {credential_id} does not exist or does not belong to user",
)
# Get credential_json value - use masking for API responses
credential_json_value = (
updated_credential.credential_json.get_value(apply_mask=True)
if updated_credential.credential_json
else {}
)
return CredentialSnapshot(
source=updated_credential.source,
id=updated_credential.id,
credential_json=updated_credential.credential_json,
credential_json=credential_json_value,
user_id=updated_credential.user_id,
name=updated_credential.name,
admin_public=updated_credential.admin_public,

View File

@@ -28,7 +28,6 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.federated.models import FederatedConnectorStatus
from onyx.server.utils import mask_credential_dict
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -145,13 +144,21 @@ class CredentialSnapshot(CredentialBase):
@classmethod
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
# Get the credential_json value with appropriate masking
if credential.credential_json is None:
credential_json_value: dict[str, Any] = {}
elif MASK_CREDENTIAL_PREFIX:
credential_json_value = credential.credential_json.get_value(
apply_mask=True
)
else:
credential_json_value = credential.credential_json.get_value(
apply_mask=False
)
return CredentialSnapshot(
id=credential.id,
credential_json=(
mask_credential_dict(credential.credential_json)
if MASK_CREDENTIAL_PREFIX and credential.credential_json
else credential.credential_json
),
credential_json=credential_json_value,
user_id=credential.user_id,
user_email=credential.user.email if credential.user else None,
admin_public=credential.admin_public,

View File

@@ -88,7 +88,7 @@ SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes")
# Container image for sandbox pods
# Should include Next.js template, opencode CLI, and demo_data zip
SANDBOX_CONTAINER_IMAGE = os.environ.get(
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.2"
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.3"
)
# S3 bucket for sandbox file storage (snapshots, knowledge files, uploads)

View File

@@ -0,0 +1,119 @@
# Sandbox Container Image
This directory contains the Dockerfile and resources for building the Onyx Craft sandbox container image.
## Directory Structure
```
docker/
├── Dockerfile # Main container image definition
├── demo_data.zip # Demo data (extracted to /workspace/demo_data)
├── templates/
│ └── outputs/ # Web app scaffold template (Next.js)
├── initial-requirements.txt # Python packages pre-installed in sandbox
├── generate_agents_md.py # Script to generate AGENTS.md for sessions
└── README.md # This file
```
## Building the Image
The sandbox image must be built for **amd64** architecture since our Kubernetes cluster runs on x86_64 nodes.
### Build for amd64 only (fastest)
```bash
cd backend/onyx/server/features/build/sandbox/kubernetes/docker
docker build --platform linux/amd64 -t onyxdotapp/sandbox:v0.1.x .
docker push onyxdotapp/sandbox:v0.1.x
```
### Build multi-arch (recommended for flexibility)
```bash
docker buildx build --platform linux/amd64,linux/arm64 \
-t onyxdotapp/sandbox:v0.1.x \
--push .
```
### Update the `latest` tag
After pushing a versioned tag, update `latest`:
```bash
docker tag onyxdotapp/sandbox:v0.1.x onyxdotapp/sandbox:latest
docker push onyxdotapp/sandbox:latest
```
Or with buildx:
```bash
docker buildx build --platform linux/amd64,linux/arm64 \
-t onyxdotapp/sandbox:v0.1.x \
-t onyxdotapp/sandbox:latest \
--push .
```
## Deploying a New Version
1. **Build and push** the new image (see above)
2. **Update the ConfigMap** in `cloud-deployment-yamls/danswer/configmap/env-configmap.yaml`:
```yaml
SANDBOX_CONTAINER_IMAGE: "onyxdotapp/sandbox:v0.1.x"
```
3. **Apply the ConfigMap**:
```bash
kubectl apply -f configmap/env-configmap.yaml
```
4. **Restart the API server** to pick up the new config:
```bash
kubectl rollout restart deployment/api-server -n danswer
```
5. **Delete existing sandbox pods** (they will be recreated with the new image):
```bash
kubectl delete pods -n onyx-sandboxes -l app.kubernetes.io/component=sandbox
```
## What's Baked Into the Image
- **Base**: `node:20-slim` (Debian-based)
- **Demo data**: `/workspace/demo_data/` - sample files for demo sessions
- **Templates**: `/workspace/templates/outputs/` - Next.js web app scaffold
- **Python venv**: `/workspace/.venv/` with packages from `initial-requirements.txt`
- **OpenCode CLI**: Installed in `/home/sandbox/.opencode/bin/`
## Runtime Directory Structure
When a session is created, the following structure is set up in the pod:
```
/workspace/
├── demo_data/ # Baked into image
├── files/ # Mounted volume, synced from S3
├── templates/ # Baked into image
└── sessions/
└── $session_id/
├── files/ # Symlink to /workspace/demo_data or /workspace/files
├── outputs/ # Copied from templates, contains web app
├── attachments/ # User-uploaded files
├── org_info/ # Demo persona info (if demo mode)
├── AGENTS.md # Instructions for the AI agent
└── opencode.json # OpenCode configuration
```
## Troubleshooting
### Verify image exists on Docker Hub
```bash
curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags" | jq '.results[].name'
```
### Check what image a pod is using
```bash
kubectl get pod <pod-name> -n onyx-sandboxes -o jsonpath='{.spec.containers[?(@.name=="sandbox")].image}'
```

View File

@@ -349,7 +349,11 @@ class SessionManager:
return LLMProviderConfig(
provider=default_model.llm_provider.provider,
model_name=default_model.name,
api_key=default_model.llm_provider.api_key,
api_key=(
default_model.llm_provider.api_key.get_value(apply_mask=False)
if default_model.llm_provider.api_key
else None
),
api_base=default_model.llm_provider.api_base,
)

View File

@@ -41,6 +41,7 @@ from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
from onyx.db.mcp import delete_connection_config
from onyx.db.mcp import delete_mcp_server
from onyx.db.mcp import delete_user_connection_configs_for_server
from onyx.db.mcp import extract_connection_data
from onyx.db.mcp import get_all_mcp_servers
from onyx.db.mcp import get_connection_config_by_id
from onyx.db.mcp import get_mcp_server_by_id
@@ -79,6 +80,7 @@ from onyx.server.features.tool.models import ToolSnapshot
from onyx.tools.tool_implementations.mcp.mcp_client import discover_mcp_tools
from onyx.tools.tool_implementations.mcp.mcp_client import initialize_mcp_client
from onyx.tools.tool_implementations.mcp.mcp_client import log_exception_group
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -143,7 +145,8 @@ class OnyxTokenStorage(TokenStorage):
async def get_tokens(self) -> OAuthToken | None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
tokens_raw = config.config.get(MCPOAuthKeys.TOKENS.value)
config_data = extract_connection_data(config)
tokens_raw = config_data.get(MCPOAuthKeys.TOKENS.value)
if tokens_raw:
return OAuthToken.model_validate(tokens_raw)
return None
@@ -151,14 +154,14 @@ class OnyxTokenStorage(TokenStorage):
async def set_tokens(self, tokens: OAuthToken) -> None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
config.config[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
cfg_headers = {
config_data = extract_connection_data(config)
config_data[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
config_data["headers"] = {
"Authorization": f"{tokens.token_type} {tokens.access_token}"
}
config.config["headers"] = cfg_headers
update_connection_config(config.id, db_session, config.config)
update_connection_config(config.id, db_session, config_data)
if self.alt_config_id:
update_connection_config(self.alt_config_id, db_session, config.config)
update_connection_config(self.alt_config_id, db_session, config_data)
# signal the oauth callback that token exchange is complete
r = get_redis_client()
@@ -168,19 +171,21 @@ class OnyxTokenStorage(TokenStorage):
async def get_client_info(self) -> OAuthClientInformationFull | None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
client_info_raw = config.config.get(MCPOAuthKeys.CLIENT_INFO.value)
config_data = extract_connection_data(config)
client_info_raw = config_data.get(MCPOAuthKeys.CLIENT_INFO.value)
if client_info_raw:
return OAuthClientInformationFull.model_validate(client_info_raw)
if self.alt_config_id:
alt_config = get_connection_config_by_id(self.alt_config_id, db_session)
if alt_config:
alt_client_info = alt_config.config.get(
alt_config_data = extract_connection_data(alt_config)
alt_client_info = alt_config_data.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if alt_client_info:
# Cache the admin client info on the user config for future calls
config.config[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
update_connection_config(config.id, db_session, config.config)
config_data[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
update_connection_config(config.id, db_session, config_data)
return OAuthClientInformationFull.model_validate(
alt_client_info
)
@@ -189,10 +194,11 @@ class OnyxTokenStorage(TokenStorage):
async def set_client_info(self, info: OAuthClientInformationFull) -> None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
config.config[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
update_connection_config(config.id, db_session, config.config)
config_data = extract_connection_data(config)
config_data[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
update_connection_config(config.id, db_session, config_data)
if self.alt_config_id:
update_connection_config(self.alt_config_id, db_session, config.config)
update_connection_config(self.alt_config_id, db_session, config_data)
def make_oauth_provider(
@@ -436,9 +442,12 @@ async def _connect_oauth(
db.commit()
connection_config_dict = extract_connection_data(
connection_config, apply_mask=False
)
is_connected = (
MCPOAuthKeys.CLIENT_INFO.value in connection_config.config
and connection_config.config.get("headers")
MCPOAuthKeys.CLIENT_INFO.value in connection_config_dict
and connection_config_dict.get("headers")
)
# Step 1: make unauthenticated request and parse returned www authenticate header
# Ensure we have a trailing slash for the MCP endpoint
@@ -471,7 +480,7 @@ async def _connect_oauth(
try:
x = await initialize_mcp_client(
probe_url,
connection_headers=connection_config.config.get("headers", {}),
connection_headers=connection_config_dict.get("headers", {}),
transport=transport,
auth=oauth_auth,
)
@@ -684,15 +693,18 @@ def save_user_credentials(
# Use template to create the full connection config
try:
# TODO: fix and/or type correctly w/base model
auth_template_dict = extract_connection_data(
auth_template, apply_mask=False
)
config_data = MCPConnectionData(
headers=auth_template.config.get("headers", {}),
headers=auth_template_dict.get("headers", {}),
header_substitutions=request.credentials,
)
for oauth_field_key in MCPOAuthKeys:
field_key: Literal["client_info", "tokens", "metadata"] = (
oauth_field_key.value
)
if field_val := auth_template.config.get(field_key):
if field_val := auth_template_dict.get(field_key):
config_data[field_key] = field_val
except Exception as e:
@@ -839,18 +851,20 @@ def _db_mcp_server_to_api_mcp_server(
and db_server.admin_connection_config is not None
and include_auth_config
):
admin_config_dict = extract_connection_data(
db_server.admin_connection_config, apply_mask=False
)
if db_server.auth_type == MCPAuthenticationType.API_TOKEN:
raw_api_key = admin_config_dict["headers"]["Authorization"].split(" ")[
-1
]
admin_credentials = {
"api_key": db_server.admin_connection_config.config["headers"][
"Authorization"
].split(" ")[-1]
"api_key": mask_string(raw_api_key),
}
elif db_server.auth_type == MCPAuthenticationType.OAUTH:
user_authenticated = False
client_info = None
client_info_raw = db_server.admin_connection_config.config.get(
MCPOAuthKeys.CLIENT_INFO.value
)
client_info_raw = admin_config_dict.get(MCPOAuthKeys.CLIENT_INFO.value)
if client_info_raw:
client_info = OAuthClientInformationFull.model_validate(
client_info_raw
@@ -861,8 +875,8 @@ def _db_mcp_server_to_api_mcp_server(
"Stored client info had empty client ID or secret"
)
admin_credentials = {
"client_id": client_info.client_id,
"client_secret": client_info.client_secret,
"client_id": mask_string(client_info.client_id),
"client_secret": mask_string(client_info.client_secret),
}
else:
admin_credentials = {}
@@ -879,14 +893,18 @@ def _db_mcp_server_to_api_mcp_server(
include_auth_config
and db_server.auth_type != MCPAuthenticationType.OAUTH
):
user_credentials = user_config.config.get(HEADER_SUBSTITUTIONS, {})
user_config_dict = extract_connection_data(user_config, apply_mask=True)
user_credentials = user_config_dict.get(HEADER_SUBSTITUTIONS, {})
if (
db_server.auth_type == MCPAuthenticationType.OAUTH
and db_server.admin_connection_config
):
client_info = None
client_info_raw = db_server.admin_connection_config.config.get(
oauth_admin_config_dict = extract_connection_data(
db_server.admin_connection_config, apply_mask=False
)
client_info_raw = oauth_admin_config_dict.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if client_info_raw:
@@ -896,8 +914,8 @@ def _db_mcp_server_to_api_mcp_server(
raise ValueError("Stored client info had empty client ID or secret")
if can_view_admin_credentials:
admin_credentials = {
"client_id": client_info.client_id,
"client_secret": client_info.client_secret,
"client_id": mask_string(client_info.client_id),
"client_secret": mask_string(client_info.client_secret),
}
elif can_view_admin_credentials:
admin_credentials = {}
@@ -909,7 +927,10 @@ def _db_mcp_server_to_api_mcp_server(
try:
template_config = db_server.admin_connection_config
if template_config:
headers = template_config.config.get("headers", {})
template_config_dict = extract_connection_data(
template_config, apply_mask=False
)
headers = template_config_dict.get("headers", {})
auth_template = MCPAuthTemplate(
headers=headers,
required_fields=[], # would need to regex, not worth it
@@ -1232,7 +1253,10 @@ def _list_mcp_tools_by_id(
)
if connection_config:
headers.update(connection_config.config.get("headers", {}))
connection_config_dict = extract_connection_data(
connection_config, apply_mask=False
)
headers.update(connection_config_dict.get("headers", {}))
import time
@@ -1320,7 +1344,10 @@ def _upsert_mcp_server(
_ensure_mcp_server_owner_or_admin(mcp_server, user)
client_info = None
if mcp_server.admin_connection_config:
client_info_raw = mcp_server.admin_connection_config.config.get(
existing_admin_config_dict = extract_connection_data(
mcp_server.admin_connection_config, apply_mask=False
)
client_info_raw = existing_admin_config_dict.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if client_info_raw:

View File

@@ -32,11 +32,16 @@ def get_user_oauth_token_status(
and whether their tokens are expired.
"""
user_tokens = get_all_user_oauth_tokens(user.id, db_session)
return [
OAuthTokenStatus(
oauth_config_id=token.oauth_config_id,
expires_at=OAuthTokenManager.token_expiration_time(token.token_data),
is_expired=OAuthTokenManager.is_token_expired(token.token_data),
result = []
for token in user_tokens:
token_data = (
token.token_data.get_value(apply_mask=False) if token.token_data else {}
)
for token in user_tokens
]
result.append(
OAuthTokenStatus(
oauth_config_id=token.oauth_config_id,
expires_at=OAuthTokenManager.token_expiration_time(token_data),
is_expired=OAuthTokenManager.is_token_expired(token_data),
)
)
return result

View File

@@ -75,7 +75,7 @@ def _get_active_search_provider(
has_api_key=bool(provider_model.api_key),
)
if not provider_model.api_key:
if provider_model.api_key is None:
raise HTTPException(
status_code=400,
detail="Web search provider requires an API key.",
@@ -84,7 +84,7 @@ def _get_active_search_provider(
try:
provider: WebSearchProvider = build_search_provider_from_config(
provider_type=provider_view.provider_type,
api_key=provider_model.api_key,
api_key=provider_model.api_key.get_value(apply_mask=False),
config=provider_model.config or {},
)
except ValueError as exc:
@@ -121,7 +121,7 @@ def _get_active_content_provider(
provider: WebContentProvider | None = build_content_provider_from_config(
provider_type=provider_type,
api_key=provider_model.api_key,
api_key=provider_model.api_key.get_value(apply_mask=False),
config=config,
)
except ValueError as exc:

View File

@@ -114,9 +114,14 @@ def get_entities(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
entities_spec = connector_instance.configuration_schema()
@@ -151,9 +156,14 @@ def get_credentials_schema(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
credentials_spec = connector_instance.credentials_schema()
@@ -275,6 +285,8 @@ def validate_entities(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
return Response(status_code=400)
# For HEAD requests, we'll expect entities as query parameters
# since HEAD requests shouldn't have request bodies
@@ -288,7 +300,8 @@ def validate_entities(
return Response(status_code=400)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
is_valid = connector_instance.validate_entities(entities_dict)
@@ -318,9 +331,15 @@ def get_authorize_url(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
# Update credentials to include the correct redirect URI with the connector ID
updated_credentials = federated_connector.credentials.copy()
updated_credentials = federated_connector.credentials.get_value(
apply_mask=False
).copy()
if "redirect_uri" in updated_credentials and updated_credentials["redirect_uri"]:
# Replace the {id} placeholder with the actual federated connector ID
updated_credentials["redirect_uri"] = updated_credentials[
@@ -391,9 +410,14 @@ def handle_oauth_callback_generic(
)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
oauth_result = connector_instance.callback(callback_data, get_oauth_callback_uri())
@@ -460,9 +484,9 @@ def get_user_oauth_status(
# Generate authorize URL if needed
authorize_url = None
if not oauth_token:
if not oauth_token and fc.credentials is not None:
connector_instance = _get_federated_connector_instance(
fc.source, fc.credentials
fc.source, fc.credentials.get_value(apply_mask=False)
)
base_authorize_url = connector_instance.authorize(get_oauth_callback_uri())
@@ -496,6 +520,10 @@ def get_federated_connector_detail(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
# Get OAuth token information for the current user
oauth_token = None
@@ -521,7 +549,9 @@ def get_federated_connector_detail(
id=federated_connector.id,
source=federated_connector.source,
name=f"{federated_connector.source.replace('_', ' ').title()}",
credentials=FederatedConnectorCredentials(**federated_connector.credentials),
credentials=FederatedConnectorCredentials(
**federated_connector.credentials.get_value(apply_mask=True)
),
config=federated_connector.config,
oauth_token_exists=oauth_token is not None,
oauth_token_expires_at=oauth_token.expires_at if oauth_token else None,

View File

@@ -16,7 +16,7 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.embedding.models import TestEmbeddingRequest
from onyx.server.utils import mask_string
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT

View File

@@ -37,7 +37,11 @@ class CloudEmbeddingProvider(BaseModel):
) -> "CloudEmbeddingProvider":
return cls(
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_key=(
cloud_provider_model.api_key.get_value(apply_mask=True)
if cloud_provider_model.api_key
else None
),
api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,

View File

@@ -90,7 +90,11 @@ def _build_llm_provider_request(
return LLMProviderUpsertRequest(
name=f"Image Gen - {image_provider_id}",
provider=source_provider.provider,
api_key=source_provider.api_key, # Only this from source
api_key=(
source_provider.api_key.get_value(apply_mask=False)
if source_provider.api_key
else None
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
@@ -227,7 +231,11 @@ def test_image_generation(
api_key_changed=False, # Using stored key from source provider
)
api_key = source_provider.api_key
api_key = (
source_provider.api_key.get_value(apply_mask=False)
if source_provider.api_key
else None
)
provider = source_provider.provider
if provider is None:
@@ -431,7 +439,11 @@ def update_config(
api_key_changed=False,
)
# Preserve existing API key when user didn't change it
actual_api_key = old_provider.api_key
actual_api_key = (
old_provider.api_key.get_value(apply_mask=False)
if old_provider.api_key
else None
)
# 3. Build and create new LLM provider
provider_request = _build_llm_provider_request(

View File

@@ -140,7 +140,11 @@ class ImageGenerationCredentials(BaseModel):
"""
llm_provider = config.model_configuration.llm_provider
return cls(
api_key=_mask_api_key(llm_provider.api_key),
api_key=_mask_api_key(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,
@@ -168,7 +172,11 @@ class DefaultImageGenerationConfig(BaseModel):
model_configuration_id=config.model_configuration_id,
model_name=config.model_configuration.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -203,7 +203,11 @@ def test_llm_configuration(
new_custom_config=test_llm_request.custom_config,
api_key_changed=False,
)
test_api_key = existing_provider.api_key
test_api_key = (
existing_provider.api_key.get_value(apply_mask=False)
if existing_provider.api_key
else None
)
if existing_provider and not test_llm_request.custom_config_changed:
test_custom_config = existing_provider.custom_config
@@ -255,7 +259,7 @@ def list_llm_providers(
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session=db_session,
flow_type_filter=[],
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
@@ -351,7 +355,11 @@ def put_llm_provider(
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
llm_provider_upsert_request.api_key = existing_provider.api_key
llm_provider_upsert_request.api_key = (
existing_provider.api_key.get_value(apply_mask=False)
if existing_provider.api_key
else None
)
if existing_provider and not llm_provider_upsert_request.custom_config_changed:
llm_provider_upsert_request.custom_config = existing_provider.custom_config
@@ -503,7 +511,9 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(db_session, [])
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN
@@ -512,9 +522,9 @@ def list_llm_provider_basics(
for provider in all_providers:
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes public providers WITHOUT persona restrictions
# - Includes all public providers
# - Includes providers user can access via group membership
# - Excludes providers with persona restrictions (requires specific persona)
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
@@ -539,7 +549,7 @@ def get_valid_model_names_for_persona(
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
available to the user when using this persona, respecting all RBAC restrictions.
Public providers are included unless they have persona restrictions that exclude this persona.
Public providers are always included.
"""
persona = fetch_persona_with_groups(db_session, persona_id)
if not persona:
@@ -553,7 +563,7 @@ def get_valid_model_names_for_persona(
valid_models = []
for llm_provider_model in all_providers:
# Check access with persona context — respects all RBAC restrictions
# Public providers always included, restricted checked via RBAC
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):
@@ -574,7 +584,7 @@ def list_llm_providers_for_persona(
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
- Public providers (respecting persona restrictions if set)
- All public providers (is_public=True) - ALWAYS included
- Restricted providers user can access via group/persona restrictions
This endpoint is used for background fetching of restricted providers
@@ -603,7 +613,7 @@ def list_llm_providers_for_persona(
llm_provider_list: list[LLMProviderDescriptor] = []
for llm_provider_model in all_providers:
# Check access with persona context — respects persona restrictions
# Use simplified access check - public providers always included
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):
@@ -644,7 +654,11 @@ def get_provider_contextual_cost(
provider=provider.provider,
model=model_configuration.name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_key=(
provider.api_key.get_value(apply_mask=False)
if provider.api_key
else None
),
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
@@ -924,6 +938,11 @@ def get_ollama_available_models(
)
)
sorted_results = sorted(
all_models_with_context_size_and_vision,
key=lambda m: m.name.lower(),
)
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
@@ -934,7 +953,7 @@ def get_ollama_available_models(
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in all_models_with_context_size_and_vision
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
@@ -948,7 +967,7 @@ def get_ollama_available_models(
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
return all_models_with_context_size_and_vision
return sorted_results
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:

View File

@@ -190,7 +190,11 @@ class LLMProviderView(LLMProvider):
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
api_key=llm_provider_model.api_key,
api_key=(
llm_provider_model.api_key.get_value(apply_mask=False)
if llm_provider_model.api_key
else None
),
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,

View File

@@ -79,6 +79,7 @@ class UserPersonalization(BaseModel):
role: str = ""
use_memories: bool = True
memories: list[str] = Field(default_factory=list)
user_preferences: str = ""
class TenantSnapshot(BaseModel):
@@ -160,6 +161,7 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
memories=[memory.memory_text for memory in (user.memories or [])],
user_preferences=user.user_preferences or "",
),
)
@@ -213,6 +215,7 @@ class PersonalizationUpdateRequest(BaseModel):
role: str | None = None
use_memories: bool | None = None
memories: list[str] | None = None
user_preferences: str | None = Field(default=None, max_length=500)
class SlackBotCreationRequest(BaseModel):
@@ -341,9 +344,21 @@ class SlackBot(BaseModel):
name=slack_bot_model.name,
enabled=slack_bot_model.enabled,
configs_count=len(slack_bot_model.slack_channel_configs),
bot_token=slack_bot_model.bot_token,
app_token=slack_bot_model.app_token,
user_token=slack_bot_model.user_token,
bot_token=(
slack_bot_model.bot_token.get_value(apply_mask=True)
if slack_bot_model.bot_token
else ""
),
app_token=(
slack_bot_model.app_token.get_value(apply_mask=True)
if slack_bot_model.app_token
else ""
),
user_token=(
slack_bot_model.user_token.get_value(apply_mask=True)
if slack_bot_model.user_token
else None
),
)

View File

@@ -30,14 +30,12 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import enforce_seat_limit
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import USER_AUTH_SECRET
@@ -92,7 +90,6 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.server.usage_limits import is_tenant_on_trial_fn
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -394,20 +391,14 @@ def bulk_invite_users(
if e not in existing_users and e not in already_invited
]
# Limit bulk invites for trial tenants to prevent email spam
# Only count new invites, not re-invites of existing users
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
current_invited = len(already_invited)
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
raise HTTPException(
status_code=403,
detail="You have hit your invite limit. "
"Please upgrade for unlimited invites.",
)
# Check seat availability for new users
if emails_needing_seats:
enforce_seat_limit(db_session, seats_needed=len(emails_needing_seats))
# Only for self-hosted (non-multi-tenant) deployments
if not MULTI_TENANT and emails_needing_seats:
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=len(emails_needing_seats))
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
if MULTI_TENANT:
try:
@@ -423,10 +414,10 @@ def bulk_invite_users(
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations only to new users (not already invited or existing)
# send out email invitations if enabled
if ENABLE_EMAIL_INVITES:
try:
for email in emails_needing_seats:
for email in new_invited_emails:
send_user_email_invite(email, current_user, AUTH_TYPE)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
@@ -573,7 +564,12 @@ def activate_user_api(
# Check seat availability before activating
# Only for self-hosted (non-multi-tenant) deployments
enforce_seat_limit(db_session)
if not MULTI_TENANT:
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=1)
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
activate_user(user_to_activate, db_session)
@@ -597,17 +593,11 @@ def get_valid_domains(
@router.get("/users", tags=PUBLIC_API_TAGS)
def list_all_users_basic_info(
include_api_keys: bool = False,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserSnapshot]:
users = get_all_users(db_session)
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
]
return [MinimalUserSnapshot(id=user.id, email=user.email) for user in users]
@router.get("/get-user-role", tags=PUBLIC_API_TAGS)
@@ -854,6 +844,11 @@ def update_user_personalization_api(
new_memories = (
request.memories if request.memories is not None else existing_memories
)
new_user_preferences = (
request.user_preferences
if request.user_preferences is not None
else user.user_preferences
)
update_user_personalization(
user.id,
@@ -861,6 +856,7 @@ def update_user_personalization_api(
personal_role=new_role,
use_memories=new_use_memories,
memories=new_memories,
user_preferences=new_user_preferences,
db_session=db_session,
)

View File

@@ -194,7 +194,7 @@ def test_search_provider(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key
api_key = existing_provider.api_key.get_value(apply_mask=False)
if provider_requires_api_key and not api_key:
raise HTTPException(
@@ -391,7 +391,7 @@ def test_content_provider(
detail="Base URL cannot differ from stored provider when using stored API key",
)
api_key = existing_provider.api_key
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise HTTPException(

View File

@@ -57,11 +57,9 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
deep_research_enabled: bool | None = None
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -8,10 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from fastapi import status
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
@@ -45,42 +41,6 @@ def get_json_line(
return json.dumps(json_dict, cls=encoder) + "\n"
def mask_string(sensitive_str: str) -> str:
return "****...**" + sensitive_str[-4:]
MASK_CREDENTIALS_WHITELIST = {
DB_CREDENTIALS_AUTHENTICATION_METHOD,
"wiki_base",
"cloud_name",
"cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds = {}
for key, val in credential_dict.items():
if isinstance(val, str):
# we want to pass the authentication_method field through so the frontend
# can disambiguate credentials created by different methods
if key in MASK_CREDENTIALS_WHITELIST:
masked_creds[key] = val
else:
masked_creds[key] = mask_string(val)
continue
if isinstance(val, int):
masked_creds[key] = "*****"
continue
raise ValueError(
f"Unable to mask credentials of type other than string or int, cannot process request."
f"Received type: {type(val)}"
)
return masked_creds
def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is

View File

@@ -446,7 +446,7 @@ def run_research_agent_call(
tool_calls=tool_calls,
tools=current_tools,
message_history=msg_history,
memories=None,
user_memory_context=None,
user_info=None,
citation_mapping=citation_mapping,
next_citation_num=citation_processor.get_next_citation_number(),

View File

@@ -16,6 +16,7 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
@@ -165,7 +166,7 @@ class SearchToolOverrideKwargs(BaseModel):
# without help and a specific custom prompt for this
original_query: str | None = None
message_history: list[ChatMinimalTextMessage] | None = None
memories: list[str] | None = None
user_memory_context: UserMemoryContext | None = None
user_info: str | None = None
# Used for tool calls after the first one but in the same chat turn. The reason for this is that if the initial pass through

View File

@@ -82,7 +82,11 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
model_provider=llm_provider.provider,
model_name=default_config.model_configuration.name,
temperature=GEN_AI_TEMPERATURE,
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -94,7 +94,11 @@ class ImageGenerationTool(Tool[None]):
llm_provider = config.model_configuration.llm_provider
credentials = ImageGenerationProviderCredentials(
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -142,8 +142,9 @@ class MCPTool(Tool[None]):
)
# Priority 2: Base headers from connection config (DB) - overrides request
if self.connection_config:
headers.update(self.connection_config.config.get("headers", {}))
if self.connection_config and self.connection_config.config:
config_dict = self.connection_config.config.get_value(apply_mask=False)
headers.update(config_dict.get("headers", {}))
# Priority 3: For pass-through OAuth, use the user's login OAuth token
if self._user_oauth_token:

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
@@ -22,22 +21,10 @@ from onyx.utils.web_content import title_from_url
logger = setup_logger()
DEFAULT_READ_TIMEOUT_SECONDS = 15
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
DEFAULT_MAX_WORKERS = 5
def _failed_result(url: str) -> WebContent:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
class OnyxWebCrawler(WebContentProvider):
@@ -50,14 +37,12 @@ class OnyxWebCrawler(WebContentProvider):
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
max_pdf_size_bytes: int | None = None,
max_html_size_bytes: int | None = None,
) -> None:
self._read_timeout_seconds = timeout_seconds
self._connect_timeout_seconds = connect_timeout_seconds
self._timeout_seconds = timeout_seconds
self._max_pdf_size_bytes = max_pdf_size_bytes
self._max_html_size_bytes = max_html_size_bytes
self._headers = {
@@ -66,68 +51,75 @@ class OnyxWebCrawler(WebContentProvider):
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._fetch_url_safe, urls))
def _fetch_url_safe(self, url: str) -> WebContent:
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
try:
return self._fetch_url(url)
except Exception as exc:
logger.warning(
"Onyx crawler unexpected error for %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
def _fetch_url(self, url: str) -> WebContent:
try:
# Use SSRF-safe request to prevent DNS rebinding attacks
response = ssrf_safe_get(
url,
headers=self._headers,
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
url, headers=self._headers, timeout=self._timeout_seconds
)
except SSRFException as exc:
logger.error(
"SSRF protection blocked request to %s (%s)",
"SSRF protection blocked request to %s: %s",
url,
exc.__class__.__name__,
str(exc),
)
return _failed_result(url)
except Exception as exc:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
except Exception as exc: # pragma: no cover - network failures vary
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
content_type = response.headers.get("Content-Type", "")
content = response.content
content_sniff = content[:1024] if content else None
content_sniff = response.content[:1024] if response.content else None
if is_pdf_resource(url, content_type, content_sniff):
if (
self._max_pdf_size_bytes is not None
and len(content) > self._max_pdf_size_bytes
and len(response.content) > self._max_pdf_size_bytes
):
logger.warning(
"PDF content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_pdf_size_bytes,
)
return _failed_result(url)
text_content, metadata = extract_pdf_text(content)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
text_content, metadata = extract_pdf_text(response.content)
title = title_from_pdf_metadata(metadata) or title_from_url(url)
return WebContent(
title=title,
@@ -139,19 +131,25 @@ class OnyxWebCrawler(WebContentProvider):
if (
self._max_html_size_bytes is not None
and len(content) > self._max_html_size_bytes
and len(response.content) > self._max_html_size_bytes
):
logger.warning(
"HTML content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_html_size_bytes,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
try:
decoded_html = decode_html_bytes(
content,
response.content,
content_type=content_type,
fallback_encoding=response.apparent_encoding or response.encoding,
)

View File

@@ -47,7 +47,6 @@ from onyx.tools.tool_implementations.web_search.utils import (
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.url import normalize_url as normalize_web_content_url
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -792,9 +791,7 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
for url in all_urls:
doc_id = url_to_doc_id.get(url)
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
# WebContent.link is normalized (query/fragment stripped). Match on the
# same normalized form to avoid dropping successful crawl results.
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
crawled_section = crawled_by_url.get(url)
if indexed_section and indexed_section.combined_content:
# Prefer indexed

View File

@@ -352,10 +352,17 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
if tenant_slack_bot:
bot_token = tenant_slack_bot.bot_token
access_token = (
tenant_slack_bot.user_token or tenant_slack_bot.bot_token
bot_token = (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else None
)
user_token = (
tenant_slack_bot.user_token.get_value(apply_mask=False)
if tenant_slack_bot.user_token
else None
)
access_token = user_token or bot_token
except Exception as e:
logger.warning(f"Could not fetch Slack bot tokens: {e}")
@@ -375,8 +382,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
None,
)
if slack_oauth_token:
access_token = slack_oauth_token.token
if slack_oauth_token and slack_oauth_token.token:
access_token = slack_oauth_token.token.get_value(
apply_mask=False
)
entities = slack_oauth_token.federated_connector.config or {}
except Exception as e:
logger.warning(f"Could not fetch Slack OAuth token: {e}")
@@ -550,7 +559,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if override_kwargs.message_history
else []
)
memories = override_kwargs.memories
memories = (
override_kwargs.user_memory_context.as_formatted_list()
if override_kwargs.user_memory_context
else []
)
user_info = override_kwargs.user_info
# Skip query expansion if this is a repeat search call

View File

@@ -1,260 +0,0 @@
from __future__ import annotations
from typing import Any
import requests
from fastapi import HTTPException
from onyx.tools.tool_implementations.web_search.models import (
WebSearchProvider,
)
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
# Module-level logger shared by the Brave client and helpers below.
logger = setup_logger()

# Brave "web search" REST endpoint.
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
# Upper bound applied to the `count` query parameter sent per request.
BRAVE_MAX_RESULTS_PER_REQUEST = 20
# Accepted values for the `safesearch` query parameter.
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
# Accepted values for the `freshness` query parameter
# (presumably past day/week/month/year — confirm against Brave API docs).
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
class RetryableBraveSearchError(Exception):
    """Error type used to trigger retry for transient Brave search failures.

    Raised for request-level failures and retryable HTTP statuses; it is the
    only exception type listed in the retry decorator on
    ``BraveClient._search_with_retries``.
    """
class BraveClient(WebSearchProvider):
    """Web search provider backed by the Brave Search REST API.

    Issues a GET against ``BRAVE_WEB_SEARCH_URL`` with the subscription key in
    the ``X-Subscription-Token`` header and maps the JSON payload into
    ``WebSearchResult`` objects.
    """

    def __init__(
        self,
        api_key: str,
        *,
        num_results: int = 10,
        timeout_seconds: int = 10,
        country: str | None = None,
        search_lang: str | None = None,
        ui_lang: str | None = None,
        safesearch: str | None = None,
        freshness: str | None = None,
    ) -> None:
        """Validate provider configuration and precompute request headers.

        Raises:
            ValueError: if ``timeout_seconds`` is non-positive, or if any of
                the optional country/language/safesearch/freshness settings
                fail their respective normalizers.
        """
        if timeout_seconds <= 0:
            raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")
        # Auth is carried entirely in headers; the key never appears in params.
        self._headers = {
            "Accept": "application/json",
            "X-Subscription-Token": api_key,
        }
        logger.debug(f"Count of results passed to BraveClient: {num_results}")
        # Clamp to [1, BRAVE_MAX_RESULTS_PER_REQUEST] rather than erroring.
        self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
        self._timeout_seconds = timeout_seconds
        # Each optional setting is normalized once here; the normalizers
        # return None for absent/blank values and raise ValueError on
        # malformed input.
        self._country = _normalize_country(country)
        self._search_lang = _normalize_language_code(
            search_lang, field_name="search_lang"
        )
        self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
        self._safesearch = _normalize_option(
            safesearch,
            field_name="safesearch",
            allowed_values=BRAVE_SAFESEARCH_OPTIONS,
        )
        self._freshness = _normalize_option(
            freshness,
            field_name="freshness",
            allowed_values=BRAVE_FRESHNESS_OPTIONS,
        )

    def _build_search_params(self, query: str) -> dict[str, str]:
        """Build query-string parameters, omitting unset optional settings."""
        params = {
            "q": query,
            "count": str(self._num_results),
        }
        if self._country:
            params["country"] = self._country
        if self._search_lang:
            params["search_lang"] = self._search_lang
        if self._ui_lang:
            params["ui_lang"] = self._ui_lang
        if self._safesearch:
            params["safesearch"] = self._safesearch
        if self._freshness:
            params["freshness"] = self._freshness
        return params

    # Only RetryableBraveSearchError is retried (network errors, 429, 5xx);
    # non-retryable HTTP errors propagate immediately as ValueError.
    @retry_builder(
        tries=3,
        delay=1,
        backoff=2,
        exceptions=(RetryableBraveSearchError,),
    )
    def _search_with_retries(self, query: str) -> list[WebSearchResult]:
        """Issue one search request and parse results; retried via decorator.

        Raises:
            RetryableBraveSearchError: request-level failures and retryable
                HTTP statuses (see ``_is_retryable_status``).
            ValueError: non-retryable HTTP error statuses.
        """
        params = self._build_search_params(query)
        try:
            response = requests.get(
                BRAVE_WEB_SEARCH_URL,
                headers=self._headers,
                params=params,
                timeout=self._timeout_seconds,
            )
        except requests.RequestException as exc:
            # Connection/timeout problems are treated as transient -> retry.
            raise RetryableBraveSearchError(
                f"Brave search request failed: {exc}"
            ) from exc
        try:
            response.raise_for_status()
        except requests.HTTPError as exc:
            error_msg = _build_error_message(response)
            if _is_retryable_status(response.status_code):
                raise RetryableBraveSearchError(error_msg) from exc
            raise ValueError(error_msg) from exc
        data = response.json()
        # Results live under data["web"]["results"]; either level may be
        # missing or null, hence the `or` fallbacks.
        web_results = (data.get("web") or {}).get("results") or []
        results: list[WebSearchResult] = []
        for result in web_results:
            if not isinstance(result, dict):
                continue
            link = _clean_string(result.get("url"))
            if not link:
                # A result without a URL is unusable downstream; skip it.
                continue
            title = _clean_string(result.get("title"))
            description = _clean_string(result.get("description"))
            results.append(
                WebSearchResult(
                    title=title,
                    link=link,
                    snippet=description,
                    author=None,
                    published_date=None,
                )
            )
        return results

    def search(self, query: str) -> list[WebSearchResult]:
        """Public search entry point; exhausted retries surface as ValueError."""
        try:
            return self._search_with_retries(query)
        except RetryableBraveSearchError as exc:
            raise ValueError(str(exc)) from exc

    def test_connection(self) -> dict[str, str]:
        """Run a probe search ("test") to validate the configured API key.

        Returns ``{"status": "ok"}`` on success; otherwise raises
        HTTPException(400) with a detail classified as auth failure, rate
        limit, or generic validation failure.
        """
        try:
            test_results = self.search("test")
            if not test_results or not any(result.link for result in test_results):
                raise HTTPException(
                    status_code=400,
                    detail="Brave API key validation failed: search returned no results.",
                )
        except HTTPException:
            raise
        except (ValueError, requests.RequestException) as e:
            error_msg = str(e)
            lower = error_msg.lower()
            # Classification is string-based, matching the message format
            # produced by _build_error_message / requests.
            if (
                "status 401" in lower
                or "status 403" in lower
                or "api key" in lower
                or "auth" in lower
            ):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid Brave API key: {error_msg}",
                ) from e
            if "status 429" in lower or "rate limit" in lower:
                raise HTTPException(
                    status_code=400,
                    detail=f"Brave API rate limit exceeded: {error_msg}",
                ) from e
            raise HTTPException(
                status_code=400,
                detail=f"Brave API key validation failed: {error_msg}",
            ) from e
        logger.info("Web search provider test succeeded for Brave.")
        return {"status": "ok"}
def _build_error_message(response: requests.Response) -> str:
    """Compose a one-line failure summary: status code plus extracted detail."""
    detail = _extract_error_detail(response)
    return f"Brave search failed (status {response.status_code}): {detail}"
def _extract_error_detail(response: requests.Response) -> str:
    """Extract the most specific human-readable error text from a Brave response.

    Preference order for JSON bodies: error.detail / error.message ->
    error (when it is a plain string) -> top-level message -> truncated
    repr of the payload. Non-JSON bodies fall back to truncated raw text.
    """
    try:
        payload: Any = response.json()
    except Exception:
        body = response.text.strip()
        return body[:200] if body else "No error details"
    if not isinstance(payload, dict):
        return str(payload)[:200]
    error = payload.get("error")
    if isinstance(error, dict):
        detail = error.get("detail") or error.get("message")
        if isinstance(detail, str):
            return detail
    if isinstance(error, str):
        return error
    message = payload.get("message")
    if isinstance(message, str):
        return message
    return str(payload)[:200]
def _is_retryable_status(status_code: int) -> bool:
return status_code == 429 or status_code >= 500
def _clean_string(value: Any) -> str:
return value.strip() if isinstance(value, str) else ""
def _normalize_country(country: str | None) -> str | None:
if country is None:
return None
normalized = country.strip().upper()
if not normalized:
return None
if len(normalized) != 2 or not normalized.isalpha():
raise ValueError(
"Brave provider config 'country' must be a 2-letter ISO country code."
)
return normalized
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 20:
raise ValueError(f"Brave provider config '{field_name}' is too long.")
return normalized
def _normalize_option(
value: str | None,
*,
field_name: str,
allowed_values: set[str],
) -> str | None:
if value is None:
return None
normalized = value.strip().lower()
if not normalized:
return None
if normalized not in allowed_values:
allowed = ", ".join(sorted(allowed_values))
raise ValueError(
f"Brave provider config '{field_name}' must be one of: {allowed}."
)
return normalized

View File

@@ -13,9 +13,6 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
DEFAULT_MAX_PDF_SIZE_BYTES,
)
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
BraveClient,
)
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
ExaClient,
)
@@ -38,76 +35,16 @@ from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def _parse_positive_int_config(
*,
raw_value: str | None,
default: int,
provider_name: str,
config_key: str,
) -> int:
if not raw_value:
return default
try:
value = int(raw_value)
except ValueError as exc:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be an integer."
) from exc
if value <= 0:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be greater than 0."
)
return value
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
    """Return True if the given provider type requires an API key.

    SearXNG is the lone exception: it federates public search engines that
    need no key. An individual SearXNG instance *can* be configured to demand
    one, but SearXNG itself does not.
    """
    return provider_type is not WebSearchProviderType.SEARXNG
def build_search_provider_from_config(
provider_type: WebSearchProviderType,
api_key: str | None,
api_key: str,
config: dict[str, str] | None, # TODO use a typed object
) -> WebSearchProvider:
config = config or {}
num_results = int(config.get("num_results") or DEFAULT_MAX_RESULTS)
# SearXNG does not require an API key
if provider_type == WebSearchProviderType.SEARXNG:
searxng_base_url = config.get("searxng_base_url")
if not searxng_base_url:
raise ValueError("Please provide a URL for your private SearXNG instance.")
return SearXNGClient(
searxng_base_url,
num_results=num_results,
)
# All other providers require an API key
if not api_key:
raise ValueError(f"API key is required for {provider_type.value} provider.")
if provider_type == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.BRAVE:
return BraveClient(
api_key=api_key,
num_results=num_results,
timeout_seconds=_parse_positive_int_config(
raw_value=config.get("timeout_seconds"),
default=10,
provider_name="Brave",
config_key="timeout_seconds",
),
country=config.get("country"),
search_lang=config.get("search_lang"),
ui_lang=config.get("ui_lang"),
safesearch=config.get("safesearch"),
freshness=config.get("freshness"),
)
if provider_type == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.GOOGLE_PSE:
@@ -127,13 +64,24 @@ def build_search_provider_from_config(
num_results=num_results,
timeout_seconds=int(config.get("timeout_seconds") or 10),
)
raise ValueError(f"Unknown provider type: {provider_type.value}")
if provider_type == WebSearchProviderType.SEARXNG:
searxng_base_url = config.get("searxng_base_url")
if not searxng_base_url:
raise ValueError("Please provide a URL for your private SearXNG instance.")
return SearXNGClient(
searxng_base_url,
num_results=num_results,
)
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
return build_search_provider_from_config(
provider_type=WebSearchProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
api_key=(
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else ""
),
config=provider_model.config or {},
)
@@ -185,7 +133,11 @@ def get_default_content_provider() -> WebContentProvider:
if provider_model:
provider = build_content_provider_from_config(
provider_type=WebContentProviderType(provider_model.provider_type),
api_key=provider_model.api_key or "",
api_key=(
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else ""
),
config=provider_model.config or WebContentProviderConfig(),
)
if provider:

View File

@@ -69,7 +69,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
if provider_model is None:
raise RuntimeError("No web search provider configured.")
provider_type = WebSearchProviderType(provider_model.provider_type)
api_key = provider_model.api_key
api_key = (
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else None
)
config = provider_model.config
# TODO - This should just be enforced at the DB level

View File

@@ -6,6 +6,7 @@ import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -220,7 +221,7 @@ def run_tool_calls(
tools: list[Tool],
# The stuff below is needed for the different individual built-in tools
message_history: list[ChatMessageSimple],
memories: list[str] | None,
user_memory_context: UserMemoryContext | None,
user_info: str | None,
citation_mapping: dict[int, str],
next_citation_num: int,
@@ -252,7 +253,7 @@ def run_tool_calls(
tools: List of available tool instances.
message_history: Chat message history (used to find the most recent user query
for `SearchTool` override kwargs).
memories: User memories, if available (passed through to `SearchTool`).
user_memory_context: User memory context, if available (passed through to `SearchTool`).
user_info: User information string, if available (passed through to `SearchTool`).
citation_mapping: Current citation number to URL mapping. May be updated with
new citations produced by search tools.
@@ -342,7 +343,7 @@ def run_tool_calls(
starting_citation_num=starting_citation_num,
original_query=last_user_message,
message_history=minimal_history,
memories=memories,
user_memory_context=user_memory_context,
user_info=user_info,
skip_query_expansion=skip_search_query_expansion,
)

View File

@@ -1,22 +1,91 @@
from typing import Any
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _encrypt_string(input_str: str) -> bytes:
    # MIT build: no real encryption — just encode to bytes. Warn (once per
    # call site) if an encryption key is configured, since it will be ignored.
    if ENCRYPTION_KEY_SECRET:
        logger.warning("MIT version of Onyx does not support encryption of secrets.")
    return input_str.encode()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _decrypt_bytes(input_bytes: bytes) -> str:
# No need to double warn. If you wish to learn more about encryption features
# refer to the Onyx EE code
return input_bytes.decode()
def mask_string(sensitive_str: str) -> str:
    """Mask a sensitive string, keeping its first and last four characters.

    Strings too short to mask safely (fewer than 14 characters, i.e. not
    enough to hide at least six) come back as a fixed bullet placeholder so
    nothing leaks.
    """
    head, tail, min_hidden = 4, 4, 6
    if len(sensitive_str) < head + tail + min_hidden:
        return "••••••••••••"
    return f"{sensitive_str[:head]}...{sensitive_str[-tail:]}"
# Credential keys whose values pass through UNMASKED: they are not secrets and
# the frontend needs the raw values (e.g. to disambiguate how a credential was
# created, or to display the configured wiki/cloud instance).
MASK_CREDENTIALS_WHITELIST = {
    DB_CREDENTIALS_AUTHENTICATION_METHOD,
    "wiki_base",
    "cloud_name",
    "cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of a credential dict with sensitive values masked.

    Strings are masked via mask_string unless whitelisted, nested dicts and
    lists are masked recursively, bools and None pass through unchanged, and
    every other type (ints, floats, exotic objects) collapses to "*****".
    """
    masked: dict[str, Any] = {}
    for key, value in credential_dict.items():
        if isinstance(value, str):
            # Whitelisted fields (e.g. authentication_method) pass through so
            # the frontend can disambiguate credentials created differently.
            masked[key] = (
                value if key in MASK_CREDENTIALS_WHITELIST else mask_string(value)
            )
        elif isinstance(value, dict):
            masked[key] = mask_credential_dict(value)
        elif isinstance(value, list):
            masked[key] = _mask_list(value)
        elif isinstance(value, (bool, type(None))):
            # bool must be checked before any numeric handling: bool is an
            # int subclass and would otherwise be masked.
            masked[key] = value
        else:
            masked[key] = "*****"
    return masked
def _mask_list(items: list[Any]) -> list[Any]:
masked: list[Any] = []
for item in items:
if isinstance(item, dict):
masked.append(mask_credential_dict(item))
elif isinstance(item, str):
masked.append(mask_string(item))
elif isinstance(item, list):
masked.append(_mask_list(item))
elif isinstance(item, (bool, type(None))):
masked.append(item)
else:
masked.append("*****")
return masked
def encrypt_string_to_bytes(intput_str: str) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"

View File

@@ -0,0 +1,205 @@
"""
Wrapper class for sensitive values that require explicit masking decisions.
This module provides a wrapper for encrypted values that forces developers to
make an explicit decision about whether to mask the value when accessing it.
This prevents accidental exposure of sensitive data in API responses.
"""
from __future__ import annotations
import json
from collections.abc import Callable
from typing import Any
from typing import Generic
from typing import NoReturn
from typing import TypeVar
from unittest.mock import MagicMock
from onyx.utils.encryption import mask_credential_dict
from onyx.utils.encryption import mask_string
T = TypeVar("T", str, dict[str, Any])
def make_mock_sensitive_value(value: dict[str, Any] | str | None) -> MagicMock:
    """
    Create a mock SensitiveValue for use in tests.

    Builds a MagicMock that quacks like SensitiveValue: its get_value()
    returns `value` regardless of the apply_mask argument, and it is truthy.
    A None input yields None (mirroring an absent credential).

    Args:
        value: The value get_value() should return — a dict, string, or None.

    Returns:
        A MagicMock configured to behave like a SensitiveValue (or None).

    Example:
        >>> mock_credential = MagicMock()
        >>> mock_credential.credential_json = make_mock_sensitive_value({"api_key": "secret"})
        >>> # mock_credential.credential_json.get_value(apply_mask=False) -> {"api_key": "secret"}
    """
    if value is None:
        return None  # type: ignore[return-value]
    mock_value = MagicMock(spec=SensitiveValue)
    mock_value.get_value.return_value = value
    # Keep truthiness checks working without touching the wrapped value.
    mock_value.__bool__ = lambda self: True  # noqa: ARG005
    return mock_value
class SensitiveAccessError(Exception):
    """Raised when attempting to access a SensitiveValue without explicit masking decision."""
class SensitiveValue(Generic[T]):
    """
    Wrapper requiring explicit masking decisions for sensitive data.

    This class wraps encrypted data and forces callers to make an explicit
    decision about whether to mask the value when accessing it. This prevents
    accidental exposure of sensitive data.

    Usage:
        # Get raw value (for internal use like connectors)
        raw_value = sensitive.get_value(apply_mask=False)

        # Get masked value (for API responses)
        masked_value = sensitive.get_value(apply_mask=True)

    Raises SensitiveAccessError when:
    - Attempting to convert to string via str() or repr()
    - Attempting to iterate over the value
    - Attempting to subscript the value (e.g., value["key"])
    - Attempting to serialize to JSON without explicit get_value()
    """

    # Sentinel marking "not decrypted yet". Using None as the cache marker
    # would silently re-decrypt on every access when the plaintext
    # legitimately parses to JSON null (json.loads("null") -> None).
    _UNSET: Any = object()

    def __init__(
        self,
        *,
        encrypted_bytes: bytes,
        decrypt_fn: Callable[[bytes], str],
        is_json: bool = False,
    ) -> None:
        """
        Initialize a SensitiveValue wrapper.

        Args:
            encrypted_bytes: The encrypted bytes to wrap
            decrypt_fn: Function to decrypt bytes to string
            is_json: If True, the decrypted value is JSON and will be parsed to dict
        """
        self._encrypted_bytes = encrypted_bytes
        self._decrypt_fn = decrypt_fn
        self._is_json = is_json
        # Cache for the decrypted value to avoid repeated decryption.
        # Starts as the _UNSET sentinel (not None) so falsy results cache too.
        self._decrypted_value: Any = SensitiveValue._UNSET

    def _decrypt(self) -> T:
        """Lazily decrypt and cache the value."""
        if self._decrypted_value is SensitiveValue._UNSET:
            decrypted_str = self._decrypt_fn(self._encrypted_bytes)
            if self._is_json:
                self._decrypted_value = json.loads(decrypted_str)
            else:
                self._decrypted_value = decrypted_str
        # The return type should always match T based on the is_json flag
        return self._decrypted_value  # type: ignore[return-value]

    def get_value(
        self,
        *,
        apply_mask: bool,
        mask_fn: Callable[[T], T] | None = None,
    ) -> T:
        """
        Get the value with explicit masking decision.

        Args:
            apply_mask: Required. True = return masked value, False = return raw value
            mask_fn: Optional custom masking function. Defaults to mask_string for
                strings and mask_credential_dict for dicts.

        Returns:
            The value, either masked or raw depending on apply_mask.

        Raises:
            ValueError: if masking is requested for a value that is neither a
                dict nor a str and no custom mask_fn was supplied.
        """
        value = self._decrypt()
        if not apply_mask:
            return value
        # Apply masking
        if mask_fn is not None:
            return mask_fn(value)
        # Use default masking based on type.
        # Type narrowing doesn't work well here due to the generic T,
        # but at runtime the types will match.
        if isinstance(value, dict):
            return mask_credential_dict(value)
        elif isinstance(value, str):
            return mask_string(value)
        else:
            raise ValueError(f"Cannot mask value of type {type(value)}")

    def __bool__(self) -> bool:
        """Allow truthiness checks without exposing the value."""
        return True

    def __str__(self) -> NoReturn:
        """Prevent accidental string conversion."""
        raise SensitiveAccessError(
            "Cannot convert SensitiveValue to string. "
            "Use .get_value(apply_mask=True/False) to access the value."
        )

    def __repr__(self) -> str:
        """Prevent accidental repr exposure."""
        return "<SensitiveValue: use .get_value(apply_mask=True/False) to access>"

    def __iter__(self) -> NoReturn:
        """Prevent iteration over the value."""
        raise SensitiveAccessError(
            "Cannot iterate over SensitiveValue. "
            "Use .get_value(apply_mask=True/False) to access the value."
        )

    def __getitem__(self, key: Any) -> NoReturn:
        """Prevent subscript access."""
        raise SensitiveAccessError(
            "Cannot subscript SensitiveValue. "
            "Use .get_value(apply_mask=True/False) to access the value."
        )

    def __eq__(self, other: Any) -> bool:
        """Prevent direct comparison which might expose value."""
        if isinstance(other, SensitiveValue):
            # Compare encrypted bytes for equality check
            return self._encrypted_bytes == other._encrypted_bytes
        raise SensitiveAccessError(
            "Cannot compare SensitiveValue with non-SensitiveValue. "
            "Use .get_value(apply_mask=True/False) to access the value for comparison."
        )

    def __hash__(self) -> int:
        """Allow hashing based on encrypted bytes."""
        return hash(self._encrypted_bytes)

    # Prevent JSON serialization
    def __json__(self) -> Any:
        """Prevent JSON serialization."""
        raise SensitiveAccessError(
            "Cannot serialize SensitiveValue to JSON. "
            "Use .get_value(apply_mask=True/False) to access the value."
        )

    # For Pydantic compatibility
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
        """Prevent Pydantic from serializing without explicit get_value()."""
        raise SensitiveAccessError(
            "Cannot serialize SensitiveValue in Pydantic model. "
            "Use .get_value(apply_mask=True/False) to access the value before serialization."
        )

Some files were not shown because too many files have changed in this diff Show More